mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 48efca5f34 | |||
| 44f1bb022b | |||
| 95c2a0285c | |||
| 8ea4813936 | |||
| 650bf8c0c0 | |||
| dd4bb5d362 | |||
| b3137e7ae8 | |||
| 0d740925c5 | |||
| 030ff59c40 | |||
| 3bc56740b9 | |||
| 87a5762bf2 | |||
| 5b149927fb | |||
| 275bb689e4 | |||
| d595945309 | |||
| e8ba741f28 | |||
| 57929b73ee | |||
| 9fd4e42438 | |||
| 0176cd9669 | |||
| 2e9b29970e | |||
| a07fc1510b | |||
| 78f067acef | |||
| 0d5f97c8c7 | |||
| a987387353 | |||
| 941b6dd340 | |||
| 23251a2487 | |||
| 6af8301996 | |||
| b57ec9eda1 | |||
| c8594a2eaa | |||
| 102715a3c9 | |||
| d5e66b4f3a | |||
| f5315887ec | |||
| 6e4ac8e2ce |
+1
-1
@@ -161,7 +161,7 @@ poetry run pytest ./tests/unit/test_*.py
|
||||
To reduce build time (e.g., if no changes were made to the client-runtime component), you can use an existing Docker
|
||||
container image by setting the SANDBOX_RUNTIME_CONTAINER_IMAGE environment variable to the desired Docker image.
|
||||
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.1-nikolaik`
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.2-nikolaik`
|
||||
|
||||
## Develop inside Docker container
|
||||
|
||||
|
||||
@@ -12,7 +12,8 @@ services:
|
||||
- SANDBOX_API_HOSTNAME=host.docker.internal
|
||||
- DOCKER_HOST_ADDR=host.docker.internal
|
||||
#
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:1.1-nikolaik}
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/runtime}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-1.2-nikolaik}
|
||||
- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
+2
-1
@@ -7,7 +7,8 @@ services:
|
||||
image: openhands:latest
|
||||
container_name: openhands-app-${DATE:-}
|
||||
environment:
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:1.1-nikolaik}
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-docker.openhands.dev/openhands/runtime}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-1.2-nikolaik}
|
||||
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
+1
-1
@@ -2,7 +2,7 @@ BACKEND_HOST ?= "127.0.0.1"
|
||||
BACKEND_PORT = 3000
|
||||
BACKEND_HOST_PORT = "$(BACKEND_HOST):$(BACKEND_PORT)"
|
||||
FRONTEND_PORT = 3001
|
||||
OPENHANDS_PATH ?= "../../OpenHands"
|
||||
OPENHANDS_PATH ?= ".."
|
||||
OPENHANDS := $(OPENHANDS_PATH)
|
||||
OPENHANDS_FRONTEND_PATH = $(OPENHANDS)/frontend/build
|
||||
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
# Organization Code Review Bot - PRD
|
||||
|
||||
## Problem Statement
|
||||
|
||||
### Engineering Leader Perspective
|
||||
|
||||
AI coding agents have made code generation cheap, but verification remains the bottleneck. Engineering leaders lack visibility into:
|
||||
- **Review quality**: How effective are AI-generated code reviews vs. human reviews?
|
||||
- **Developer engagement**: Are developers acting on AI review feedback?
|
||||
- **Organizational patterns**: What recurring feedback themes exist across the organization that could be codified?
|
||||
|
||||
Without telemetry, orgs cannot measure ROI of AI-assisted code review or systematically improve their review standards.
|
||||
|
||||
### Developer Perspective
|
||||
|
||||
Current AI code reviews produce itemized feedback, but acting on that feedback is manual and disconnected:
|
||||
- Developers must context-switch to address each review item separately
|
||||
- No one-click path from "review comment" → "fix implementation"
|
||||
- No feedback loop to help the AI learn which suggestions are valuable vs. noise
|
||||
|
||||
## Goals
|
||||
|
||||
1. **One-click remediation**: Enable developers to launch an OpenHands conversation directly from a review item to address it
|
||||
2. **Org-wide telemetry**: Track accept/dismiss rates on review items to measure review quality and surface patterns
|
||||
3. **Learned review standards**: Distill recurring org-specific feedback into a lightweight review standard the bot applies automatically
|
||||
4. **Verification signals**: Integrate code survival metrics (what % of AI-written code survives to merge) to predict review quality
|
||||
|
||||
## User Personas
|
||||
|
||||
| Persona | Description |
|
||||
|---------|-------------|
|
||||
| **Developer** | Uses the bot to get PR reviews and quickly address feedback items via one-click OpenHands sessions |
|
||||
| **Tech Lead** | Reviews org-wide feedback patterns to identify common issues and improve team coding standards |
|
||||
| **Engineering Manager** | Monitors accept/dismiss telemetry to assess AI review effectiveness and developer adoption |
|
||||
| **Platform Engineer** | Configures org-specific review rules and integrates the bot with existing CI/CD workflows |
|
||||
|
||||
## Key Use Cases
|
||||
|
||||
### 1. One-Click Review Item Remediation
|
||||
- Developer receives AI code review with itemized feedback
|
||||
- Each feedback item has a "Fix with OpenHands" button that launches a scoped conversation to address that specific issue
|
||||
- Context (diff, review comment, file) is automatically passed to the agent
|
||||
|
||||
### 2. Accept/Dismiss Feedback Telemetry
|
||||
- Developers can mark review items as "Agree & Fix" or "Dismiss"
|
||||
- Org-wide dashboard shows aggregate accept/dismiss rates per review category
|
||||
- Identifies high-value feedback patterns vs. low-signal noise
|
||||
|
||||
### 3. Org-Specific Review Standards
|
||||
- Platform engineer configures org-specific review rules (e.g., "always check for error handling in API routes")
|
||||
- Bot learns from historical code reviews to surface org-specific patterns
|
||||
- Review standards are versioned and auditable
|
||||
|
||||
### 4. Code Survival Metrics
|
||||
- Track what fraction of AI-suggested changes make it into the merged PR
|
||||
- Surface low-survival patterns to improve review prompt quality
|
||||
- Use survival signals to predict whether a review item is likely to be addressed
|
||||
|
||||
### 5. Review Quality Dashboard
|
||||
- Engineering managers see per-team and per-repo review effectiveness metrics
|
||||
- Trending view of common feedback categories over time
|
||||
- Alerts when review quality drops or dismiss rates spike
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
from integrations.utils import store_repositories_in_db
|
||||
from integrations.store_repo_utils import store_repositories_in_db
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
|
||||
@@ -25,9 +25,9 @@ from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.org_store import OrgStore
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.agent_server.models import SendMessageRequest
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
@@ -78,19 +78,17 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
|
||||
if not settings or settings.enable_proactive_conversation_starters is None:
|
||||
# Check global setting first - if disabled globally, return False
|
||||
if not ENABLE_PROACTIVE_CONVERSATION_STARTERS:
|
||||
return False
|
||||
|
||||
return settings.enable_proactive_conversation_starters
|
||||
def _get_setting():
|
||||
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
return False
|
||||
return bool(org.enable_proactive_conversation_starters)
|
||||
|
||||
return await call_sync_from_async(_get_setting)
|
||||
|
||||
|
||||
# =================================================
|
||||
@@ -151,6 +149,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
issue_body=self.description,
|
||||
previous_comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
@@ -189,6 +188,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
self.conversation_id = conversation_metadata.conversation_id
|
||||
return conversation_metadata
|
||||
|
||||
@@ -327,7 +327,6 @@ class GithubIssueComment(GithubIssue):
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
issue_title=self.title,
|
||||
@@ -397,7 +396,6 @@ class GithubInlinePRComment(GithubPRComment):
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'pr_update_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
pr_number=self.issue_number,
|
||||
pr_title=self.title,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
|
||||
from integrations.store_repo_utils import store_repositories_in_db
|
||||
from integrations.types import GitLabResourceType
|
||||
from integrations.utils import store_repositories_in_db
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
from integrations.jira.jira_types import JiraViewInterface
|
||||
from integrations.jira.jira_view import (
|
||||
JiraExistingConversationView,
|
||||
JiraFactory,
|
||||
JiraNewConversationView,
|
||||
)
|
||||
@@ -88,100 +84,53 @@ class JiraManager(Manager):
|
||||
)
|
||||
return repos
|
||||
|
||||
async def validate_request(
|
||||
self, request: Request
|
||||
) -> Tuple[bool, Optional[str], Optional[Dict]]:
|
||||
"""Verify Jira webhook signature."""
|
||||
signature_header = request.headers.get('x-hub-signature')
|
||||
signature = signature_header.split('=')[1] if signature_header else None
|
||||
body = await request.body()
|
||||
payload = await request.json()
|
||||
workspace_name = ''
|
||||
|
||||
def get_workspace_name_from_payload(self, payload: dict) -> str | None:
|
||||
"""Extract workspace name from Jira webhook payload."""
|
||||
if payload.get('webhookEvent') == 'comment_created':
|
||||
selfUrl = payload.get('comment', {}).get('author', {}).get('self')
|
||||
elif payload.get('webhookEvent') == 'jira:issue_updated':
|
||||
selfUrl = payload.get('user', {}).get('self')
|
||||
else:
|
||||
workspace_name = ''
|
||||
|
||||
parsedUrl = urlparse(selfUrl)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
|
||||
if not workspace_name:
|
||||
logger.warning('[Jira] No workspace name found in webhook payload')
|
||||
return False, None, None
|
||||
|
||||
if not signature:
|
||||
logger.warning('[Jira] No signature found in webhook headers')
|
||||
return False, None, None
|
||||
|
||||
workspace = await self.integration_store.get_workspace_by_name(workspace_name)
|
||||
|
||||
if not workspace:
|
||||
logger.warning('[Jira] Could not identify workspace for webhook')
|
||||
return False, None, None
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Jira] Workspace {workspace.id} is not active')
|
||||
return False, None, None
|
||||
|
||||
webhook_secret = self.token_manager.decrypt_text(workspace.webhook_secret)
|
||||
digest = hmac.new(webhook_secret.encode(), body, hashlib.sha256).hexdigest()
|
||||
|
||||
if hmac.compare_digest(signature, digest):
|
||||
logger.info('[Jira] Webhook signature verified successfully')
|
||||
return True, signature, payload
|
||||
|
||||
return False, None, None
|
||||
|
||||
def parse_webhook(self, payload: Dict) -> JobContext | None:
|
||||
event_type = payload.get('webhookEvent')
|
||||
|
||||
if event_type == 'comment_created':
|
||||
comment_data = payload.get('comment', {})
|
||||
comment = comment_data.get('body', '')
|
||||
|
||||
if '@openhands' not in comment:
|
||||
return None
|
||||
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
base_api_url = issue_data.get('self', '').split('/rest/')[0]
|
||||
|
||||
user_data = comment_data.get('author', {})
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
account_id = user_data.get('accountId')
|
||||
elif event_type == 'jira:issue_updated':
|
||||
changelog = payload.get('changelog', {})
|
||||
items = changelog.get('items', [])
|
||||
labels = [
|
||||
item.get('toString', '')
|
||||
for item in items
|
||||
if item.get('field') == 'labels' and 'toString' in item
|
||||
]
|
||||
|
||||
if 'openhands' not in labels:
|
||||
return None
|
||||
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
base_api_url = issue_data.get('self', '').split('/rest/')[0]
|
||||
|
||||
user_data = payload.get('user', {})
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
account_id = user_data.get('accountId')
|
||||
comment = ''
|
||||
else:
|
||||
return None
|
||||
|
||||
workspace_name = ''
|
||||
if not selfUrl:
|
||||
return None
|
||||
|
||||
parsedUrl = urlparse(selfUrl)
|
||||
return parsedUrl.hostname or None
|
||||
|
||||
def parse_webhook(self, message: Message) -> JobContext | None:
|
||||
payload = message.message.get('payload', {})
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
self_url = issue_data.get('self', '')
|
||||
if not self_url:
|
||||
logger.warning('[Jira] Missing self URL in issue data')
|
||||
base_api_url = ''
|
||||
elif '/rest/' in self_url:
|
||||
base_api_url = self_url.split('/rest/')[0]
|
||||
else:
|
||||
# Fallback: extract base URL using urlparse
|
||||
parsed = urlparse(self_url)
|
||||
base_api_url = f'{parsed.scheme}://{parsed.netloc}'
|
||||
|
||||
comment = ''
|
||||
if JiraFactory.is_ticket_comment(message):
|
||||
comment_data = payload.get('comment', {})
|
||||
comment = comment_data.get('body', '')
|
||||
user_data: dict = comment_data.get('author', {})
|
||||
elif JiraFactory.is_labeled_ticket(message):
|
||||
user_data = payload.get('user', {})
|
||||
|
||||
else:
|
||||
raise ValueError('Unrecognized jira event')
|
||||
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
account_id = user_data.get('accountId')
|
||||
|
||||
workspace_name = ''
|
||||
parsedUrl = urlparse(base_api_url)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
@@ -210,14 +159,28 @@ class JiraManager(Manager):
|
||||
base_api_url=base_api_url,
|
||||
)
|
||||
|
||||
async def is_job_requested(self, message: Message) -> bool:
|
||||
return JiraFactory.is_labeled_ticket(message) or JiraFactory.is_ticket_comment(
|
||||
message
|
||||
)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process incoming Jira webhook message."""
|
||||
|
||||
payload = message.message.get('payload', {})
|
||||
job_context = self.parse_webhook(payload)
|
||||
logger.info('[Jira]: received payload', extra={'payload': payload})
|
||||
|
||||
is_job_requested = await self.is_job_requested(message)
|
||||
if not is_job_requested:
|
||||
return
|
||||
|
||||
job_context = self.parse_webhook(message)
|
||||
|
||||
if not job_context:
|
||||
logger.info('[Jira] Webhook does not match trigger conditions')
|
||||
logger.info(
|
||||
'[Jira] Failed to parse webhook payload - missing required fields or invalid structure',
|
||||
extra={'event_type': payload.get('webhookEvent')},
|
||||
)
|
||||
return
|
||||
|
||||
# Get workspace by user email domain
|
||||
@@ -297,21 +260,18 @@ class JiraManager(Manager):
|
||||
)
|
||||
return
|
||||
|
||||
if not await self.is_job_requested(message, jira_view):
|
||||
if not await self.is_repository_specified(message, jira_view):
|
||||
return
|
||||
|
||||
await self.start_job(jira_view)
|
||||
|
||||
async def is_job_requested(
|
||||
async def is_repository_specified(
|
||||
self, message: Message, jira_view: JiraViewInterface
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a job is requested and handle repository selection.
|
||||
"""
|
||||
|
||||
if isinstance(jira_view, JiraExistingConversationView):
|
||||
return True
|
||||
|
||||
try:
|
||||
# Get user repositories
|
||||
user_repos: list[Repository] = await self._get_repositories(
|
||||
@@ -334,7 +294,7 @@ class JiraManager(Manager):
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Error in is_job_requested: {str(e)}')
|
||||
logger.error(f'[Jira] Error determining repository: {str(e)}')
|
||||
return False
|
||||
|
||||
async def start_job(self, jira_view: JiraViewInterface):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from integrations.jira.jira_types import JiraViewInterface, StartingConvoException
|
||||
from integrations.models import JobContext
|
||||
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
|
||||
from integrations.models import JobContext, Message
|
||||
from integrations.utils import CONVERSATION_URL, HOST, get_oh_labels, has_exact_mention
|
||||
from jinja2 import Environment
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_integration_store import JiraIntegrationStore
|
||||
@@ -10,19 +10,16 @@ from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.server.services.conversation_service import (
|
||||
create_new_conversation,
|
||||
setup_init_conversation_settings,
|
||||
)
|
||||
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
integration_store = JiraIntegrationStore.get_instance()
|
||||
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraNewConversationView(JiraViewInterface):
|
||||
@@ -101,90 +98,39 @@ class JiraNewConversationView(JiraViewInterface):
|
||||
return f"I'm on it! {self.job_context.display_name} can [track my progress here|{conversation_link}]."
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraExistingConversationView(JiraViewInterface):
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_existing_conversation.j2')
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
user_message=self.job_context.user_msg or '',
|
||||
issue_title=self.job_context.issue_title,
|
||||
issue_description=self.job_context.issue_description,
|
||||
)
|
||||
|
||||
return '', user_msg
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Update an existing Jira conversation"""
|
||||
|
||||
user_id = self.jira_user.keycloak_user_id
|
||||
|
||||
try:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, user_id
|
||||
)
|
||||
|
||||
try:
|
||||
await conversation_store.get_metadata(self.conversation_id)
|
||||
except FileNotFoundError:
|
||||
raise StartingConvoException('Conversation no longer exists.')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
# Should we raise here if there are no providers?
|
||||
providers_set = list(provider_tokens.keys()) if provider_tokens else []
|
||||
|
||||
conversation_init_data = await setup_init_conversation_settings(
|
||||
user_id, self.conversation_id, providers_set
|
||||
)
|
||||
|
||||
# Either join ongoing conversation, or restart the conversation
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
self.conversation_id, conversation_init_data, user_id
|
||||
)
|
||||
|
||||
final_agent_observation = get_final_agent_observation(
|
||||
agent_loop_info.event_store
|
||||
)
|
||||
agent_state = (
|
||||
None
|
||||
if len(final_agent_observation) == 0
|
||||
else final_agent_observation[0].agent_state
|
||||
)
|
||||
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
)
|
||||
|
||||
return self.conversation_id
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
|
||||
)
|
||||
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira"""
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {self.job_context.display_name} can [continue tracking my progress here|{conversation_link}]."
|
||||
|
||||
|
||||
class JiraFactory:
|
||||
"""Factory for creating Jira views based on message content"""
|
||||
|
||||
@staticmethod
|
||||
def is_labeled_ticket(message: Message) -> bool:
|
||||
payload = message.message.get('payload', {})
|
||||
event_type = payload.get('webhookEvent')
|
||||
|
||||
if event_type != 'jira:issue_updated':
|
||||
return False
|
||||
|
||||
changelog = payload.get('changelog', {})
|
||||
items = changelog.get('items', [])
|
||||
labels = [
|
||||
item.get('toString', '')
|
||||
for item in items
|
||||
if item.get('field') == 'labels' and 'toString' in item
|
||||
]
|
||||
|
||||
return OH_LABEL in labels
|
||||
|
||||
@staticmethod
|
||||
def is_ticket_comment(message: Message) -> bool:
|
||||
payload = message.message.get('payload', {})
|
||||
event_type = payload.get('webhookEvent')
|
||||
|
||||
if event_type != 'comment_created':
|
||||
return False
|
||||
|
||||
comment_data = payload.get('comment', {})
|
||||
comment_body = comment_data.get('body', '')
|
||||
return has_exact_mention(comment_body, INLINE_OH_LABEL)
|
||||
|
||||
@staticmethod
|
||||
async def create_jira_view_from_payload(
|
||||
job_context: JobContext,
|
||||
@@ -197,23 +143,6 @@ class JiraFactory:
|
||||
if not jira_user or not saas_user_auth or not jira_workspace:
|
||||
raise StartingConvoException('User not authenticated with Jira integration')
|
||||
|
||||
conversation = await integration_store.get_user_conversations_by_issue_id(
|
||||
job_context.issue_id, jira_user.id
|
||||
)
|
||||
|
||||
if conversation:
|
||||
logger.info(
|
||||
f'[Jira] Found existing conversation for issue {job_context.issue_id}'
|
||||
)
|
||||
return JiraExistingConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
jira_user=jira_user,
|
||||
jira_workspace=jira_workspace,
|
||||
selected_repo=None,
|
||||
conversation_id=conversation.conversation_id,
|
||||
)
|
||||
|
||||
return JiraNewConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
|
||||
@@ -29,17 +29,20 @@ class ResolverUserContext(UserContext):
|
||||
|
||||
return UserInfo(id=user_id)
|
||||
|
||||
async def get_authenticated_git_url(self, repository: str) -> str:
|
||||
async def get_authenticated_git_url(
|
||||
self, repository: str, is_optional: bool = False
|
||||
) -> str:
|
||||
# This would need to be implemented based on the git provider tokens
|
||||
# For now, return a basic HTTPS URL
|
||||
return f'https://github.com/{repository}.git'
|
||||
|
||||
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
|
||||
# Return the appropriate token from git_provider_tokens
|
||||
|
||||
# Return the appropriate token string from git_provider_tokens
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if provider_tokens:
|
||||
return provider_tokens.get(provider_type)
|
||||
provider_token = provider_tokens.get(provider_type)
|
||||
if provider_token and provider_token.token:
|
||||
return provider_token.token.get_secret_value()
|
||||
return None
|
||||
|
||||
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
|
||||
|
||||
@@ -189,6 +189,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
'channel_id': self.channel_id,
|
||||
'conversation_id': self.conversation_id,
|
||||
'keycloak_user_id': user_info.keycloak_user_id,
|
||||
'org_id': user_info.org_id,
|
||||
'parent_id': self.thread_ts or self.message_ts,
|
||||
'v1_enabled': v1_enabled,
|
||||
},
|
||||
@@ -197,6 +198,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
conversation_id=self.conversation_id,
|
||||
channel_id=self.channel_id,
|
||||
keycloak_user_id=user_info.keycloak_user_id,
|
||||
org_id=user_info.org_id,
|
||||
parent_id=self.thread_ts
|
||||
or self.message_ts, # conversations can start in a thread reply as well; we should always references the parent's (root level msg's) message ID
|
||||
v1_enabled=v1_enabled,
|
||||
@@ -399,10 +401,10 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
user_msg, _ = self._get_instructions(jinja)
|
||||
user_msg_action = MessageAction(content=user_msg)
|
||||
instructions, _ = self._get_instructions(jinja)
|
||||
user_msg = MessageAction(content=instructions)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_msg_action)
|
||||
self.conversation_id, event_to_dict(user_msg)
|
||||
)
|
||||
|
||||
async def send_message_to_v1_conversation(self, jinja: Environment):
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
from storage.repository_store import RepositoryStore
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_repo_map_store import UserRepositoryMapStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import Repository
|
||||
|
||||
|
||||
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
|
||||
"""
|
||||
Store repositories in DB and create user-repository mappings
|
||||
|
||||
Args:
|
||||
repos: List of Repository objects to store
|
||||
user_id: User ID associated with these repositories
|
||||
"""
|
||||
|
||||
# Convert Repository objects to StoredRepository objects
|
||||
# Convert Repository objects to UserRepositoryMap objects
|
||||
stored_repos = []
|
||||
user_repos = []
|
||||
for repo in repos:
|
||||
repo_id = f'{repo.git_provider.value}##{str(repo.id)}'
|
||||
stored_repo = StoredRepository(
|
||||
repo_name=repo.full_name,
|
||||
repo_id=repo_id,
|
||||
is_public=repo.is_public,
|
||||
# Optional fields set to None by default
|
||||
has_microagent=None,
|
||||
has_setup_script=None,
|
||||
)
|
||||
stored_repos.append(stored_repo)
|
||||
user_repo_map = UserRepositoryMap(user_id=user_id, repo_id=repo_id, admin=None)
|
||||
|
||||
user_repos.append(user_repo_map)
|
||||
|
||||
# Get config instance
|
||||
config = OpenHandsConfig()
|
||||
|
||||
try:
|
||||
# Store repositories in the repos table
|
||||
repo_store = RepositoryStore.get_instance(config)
|
||||
repo_store.store_projects(stored_repos)
|
||||
|
||||
# Store user-repository mappings in the user-repos table
|
||||
user_repo_store = UserRepositoryMapStore.get_instance(config)
|
||||
user_repo_store.store_user_repo_mappings(user_repos)
|
||||
|
||||
logger.info(f'Saved repos for user {user_id}')
|
||||
except Exception:
|
||||
logger.warning('Failed to save repos', exc_info=True)
|
||||
@@ -1,19 +1,24 @@
|
||||
from uuid import UUID
|
||||
|
||||
import stripe
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.database import session_maker
|
||||
from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
|
||||
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
with session_maker() as session:
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.filter(StripeCustomer.org_id == org_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer:
|
||||
@@ -21,46 +26,76 @@ async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
|
||||
# If that fails, fallback to stripe
|
||||
search_result = await stripe.Customer.search_async(
|
||||
query=f"metadata['user_id']:'{user_id}'",
|
||||
query=f"metadata['org_id']:'{str(org_id)}'",
|
||||
)
|
||||
data = search_result.data
|
||||
if not data:
|
||||
logger.info('no_customer_for_user_id', extra={'user_id': user_id})
|
||||
logger.info(
|
||||
'no_customer_for_org_id',
|
||||
extra={'org_id': str(org_id)},
|
||||
)
|
||||
return None
|
||||
return data[0].id # type: ignore [attr-defined]
|
||||
|
||||
|
||||
async def find_or_create_customer(user_id: str) -> str:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id:
|
||||
return customer_id
|
||||
logger.info('creating_customer', extra={'user_id': user_id})
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
customer_id = await find_customer_id_by_org_id(org.id)
|
||||
return customer_id
|
||||
|
||||
# Get the user info from keycloak
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id) or {}
|
||||
|
||||
async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
# Get the current org for the user
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
|
||||
customer_id = await find_customer_id_by_org_id(org.id)
|
||||
if customer_id:
|
||||
return {'customer_id': customer_id, 'org_id': str(org.id)}
|
||||
logger.info(
|
||||
'creating_customer',
|
||||
extra={'user_id': user_id, 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
# Create the customer in stripe
|
||||
customer = await stripe.Customer.create_async(
|
||||
email=str(user_info.get('email', '')),
|
||||
metadata={'user_id': user_id},
|
||||
email=org.contact_email,
|
||||
metadata={'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
# Save the stripe customer in the local db
|
||||
with session_maker() as session:
|
||||
session.add(
|
||||
StripeCustomer(keycloak_user_id=user_id, stripe_customer_id=customer.id)
|
||||
StripeCustomer(
|
||||
keycloak_user_id=user_id,
|
||||
org_id=org.id,
|
||||
stripe_customer_id=customer.id,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'created_customer',
|
||||
extra={'user_id': user_id, 'stripe_customer_id': customer.id},
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
return customer.id
|
||||
return {'customer_id': customer.id, 'org_id': str(org.id)}
|
||||
|
||||
|
||||
async def has_payment_method(user_id: str) -> bool:
|
||||
async def has_payment_method_by_user_id(user_id: str) -> bool:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id is None:
|
||||
return False
|
||||
@@ -71,3 +106,28 @@ async def has_payment_method(user_id: str) -> bool:
|
||||
f'has_payment_method:{user_id}:{customer_id}:{bool(payment_methods.data)}'
|
||||
)
|
||||
return bool(payment_methods.data)
|
||||
|
||||
|
||||
async def migrate_customer(session: Session, user_id: str, org: Org):
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -6,15 +6,9 @@ import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from server.config import get_config
|
||||
from server.constants import WEB_HOST
|
||||
from storage.database import session_maker
|
||||
from storage.repository_store import RepositoryStore
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_repo_map_store import UserRepositoryMapStore
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events import Event, EventSource
|
||||
@@ -125,26 +119,17 @@ async def get_user_v1_enabled_setting(user_id: str | None) -> bool:
|
||||
Returns:
|
||||
True if V1 conversations are enabled for this user, False otherwise
|
||||
"""
|
||||
|
||||
# If no user ID is provided, we can't check user settings
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
|
||||
settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
|
||||
if not settings or settings.v1_enabled is None:
|
||||
if not org or org.v1_enabled is None:
|
||||
return False
|
||||
|
||||
return settings.v1_enabled
|
||||
return org.v1_enabled
|
||||
|
||||
|
||||
def has_exact_mention(text: str, mention: str) -> bool:
|
||||
@@ -409,51 +394,6 @@ def append_conversation_footer(message: str, conversation_id: str) -> str:
|
||||
return message + footer
|
||||
|
||||
|
||||
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
|
||||
"""
|
||||
Store repositories in DB and create user-repository mappings
|
||||
|
||||
Args:
|
||||
repos: List of Repository objects to store
|
||||
user_id: User ID associated with these repositories
|
||||
"""
|
||||
|
||||
# Convert Repository objects to StoredRepository objects
|
||||
# Convert Repository objects to UserRepositoryMap objects
|
||||
stored_repos = []
|
||||
user_repos = []
|
||||
for repo in repos:
|
||||
repo_id = f'{repo.git_provider.value}##{str(repo.id)}'
|
||||
stored_repo = StoredRepository(
|
||||
repo_name=repo.full_name,
|
||||
repo_id=repo_id,
|
||||
is_public=repo.is_public,
|
||||
# Optional fields set to None by default
|
||||
has_microagent=None,
|
||||
has_setup_script=None,
|
||||
)
|
||||
stored_repos.append(stored_repo)
|
||||
user_repo_map = UserRepositoryMap(user_id=user_id, repo_id=repo_id, admin=None)
|
||||
|
||||
user_repos.append(user_repo_map)
|
||||
|
||||
# Get config instance
|
||||
config = OpenHandsConfig()
|
||||
|
||||
try:
|
||||
# Store repositories in the repos table
|
||||
repo_store = RepositoryStore.get_instance(config)
|
||||
repo_store.store_projects(stored_repos)
|
||||
|
||||
# Store user-repository mappings in the user-repos table
|
||||
user_repo_store = UserRepositoryMapStore.get_instance(config)
|
||||
user_repo_store.store_user_repo_mappings(user_repos)
|
||||
|
||||
logger.info(f'Saved repos for user {user_id}')
|
||||
except Exception:
|
||||
logger.warning('Failed to save repos', exc_info=True)
|
||||
|
||||
|
||||
def infer_repo_from_message(user_msg: str) -> list[str]:
|
||||
"""
|
||||
Extract all repository names in the format 'owner/repo' from various Git provider URLs
|
||||
|
||||
@@ -20,6 +20,8 @@ down_revision = '059'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# TODO: decide whether to modify this for orgs or users
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""
|
||||
@@ -28,8 +30,10 @@ def upgrade():
|
||||
|
||||
This replaces the functionality of the removed admin maintenance endpoint.
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
|
||||
# Hardcoded value to prevent migration failures when constant is removed from codebase
|
||||
# This migration has already run in production, so we use the value that was current at the time
|
||||
CURRENT_USER_SETTINGS_VERSION = 4
|
||||
|
||||
# Create a connection and bind it to a session
|
||||
connection = op.get_bind()
|
||||
|
||||
@@ -0,0 +1,272 @@
|
||||
"""create org tables from pgerd schema
|
||||
|
||||
Revision ID: 089
|
||||
Revises: 088
|
||||
Create Date: 2025-01-07 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '089'
|
||||
down_revision: Union[str, None] = '088'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Remove current settings table
|
||||
op.execute('DROP TABLE IF EXISTS settings')
|
||||
|
||||
# Add already_migrated column to user_settings table
|
||||
op.add_column(
|
||||
'user_settings',
|
||||
sa.Column(
|
||||
'already_migrated',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
# Create role table
|
||||
op.create_table(
|
||||
'role',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('name', sa.String, nullable=False),
|
||||
sa.Column('rank', sa.Integer, nullable=False),
|
||||
sa.UniqueConstraint('name', name='role_name_unique'),
|
||||
)
|
||||
|
||||
# 1. Create default roles
|
||||
op.execute(
|
||||
sa.text("""
|
||||
INSERT INTO role (name, rank) VALUES ('owner', 10), ('admin', 20), ('user', 1000)
|
||||
ON CONFLICT (name) DO NOTHING;
|
||||
""")
|
||||
)
|
||||
|
||||
# Create org table with settings fields
|
||||
op.create_table(
|
||||
'org',
|
||||
sa.Column(
|
||||
'id',
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('name', sa.String, nullable=False),
|
||||
sa.Column('contact_name', sa.String, nullable=True),
|
||||
sa.Column('contact_email', sa.String, nullable=True),
|
||||
sa.Column('conversation_expiration', sa.Integer, nullable=True),
|
||||
# Settings fields moved to org table
|
||||
sa.Column('agent', sa.String, nullable=True),
|
||||
sa.Column('default_max_iterations', sa.Integer, nullable=True),
|
||||
sa.Column('security_analyzer', sa.String, nullable=True),
|
||||
sa.Column(
|
||||
'confirmation_mode',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('default_llm_model', sa.String, nullable=True),
|
||||
sa.Column('default_llm_base_url', sa.String, nullable=True),
|
||||
sa.Column('remote_runtime_resource_factor', sa.Integer, nullable=True),
|
||||
sa.Column(
|
||||
'enable_default_condenser',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column('billing_margin', sa.Float, nullable=True),
|
||||
sa.Column(
|
||||
'enable_proactive_conversation_starters',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column('sandbox_base_container_image', sa.String, nullable=True),
|
||||
sa.Column('sandbox_runtime_container_image', sa.String, nullable=True),
|
||||
sa.Column(
|
||||
'org_version', sa.Integer, nullable=False, server_default=sa.text('0')
|
||||
),
|
||||
sa.Column('mcp_config', sa.JSON, nullable=True),
|
||||
sa.Column('_search_api_key', sa.String, nullable=True),
|
||||
sa.Column('_sandbox_api_key', sa.String, nullable=True),
|
||||
sa.Column('max_budget_per_task', sa.Float, nullable=True),
|
||||
sa.Column(
|
||||
'enable_solvability_analysis',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('v1_enabled', sa.Boolean, nullable=True),
|
||||
sa.Column('condenser_max_size', sa.Integer, nullable=True),
|
||||
sa.UniqueConstraint('name', name='org_name_unique'),
|
||||
)
|
||||
|
||||
# Create user table with user-specific settings fields
|
||||
op.create_table(
|
||||
'user',
|
||||
sa.Column(
|
||||
'id',
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('current_org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=True),
|
||||
sa.Column('accepted_tos', sa.DateTime, nullable=True),
|
||||
sa.Column(
|
||||
'enable_sound_notifications',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('language', sa.String, nullable=True),
|
||||
sa.Column('user_consents_to_analytics', sa.Boolean, nullable=True),
|
||||
sa.Column('email', sa.String, nullable=True),
|
||||
sa.Column('email_verified', sa.Boolean, nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
['current_org_id'], ['org.id'], name='current_org_fkey'
|
||||
),
|
||||
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='user_role_fkey'),
|
||||
)
|
||||
|
||||
# Create org_member table (junction table for many-to-many relationship)
|
||||
op.create_table(
|
||||
'org_member',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=False),
|
||||
sa.Column('_llm_api_key', sa.String, nullable=False),
|
||||
sa.Column('max_iterations', sa.Integer, nullable=True),
|
||||
sa.Column('llm_model', sa.String, nullable=True),
|
||||
sa.Column('_llm_api_key_for_byor', sa.String, nullable=True),
|
||||
sa.Column('llm_base_url', sa.String, nullable=True),
|
||||
sa.Column('status', sa.String, nullable=True),
|
||||
sa.ForeignKeyConstraint(['org_id'], ['org.id'], name='om_org_fkey'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='om_user_fkey'),
|
||||
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='om_role_fkey'),
|
||||
sa.PrimaryKeyConstraint('org_id', 'user_id'),
|
||||
)
|
||||
|
||||
# Add org_id column to existing tables
|
||||
# billing_sessions
|
||||
op.add_column(
|
||||
'billing_sessions',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'billing_sessions_org_fkey', 'billing_sessions', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# Create conversation_metadata_saas table
|
||||
op.create_table(
|
||||
'conversation_metadata_saas',
|
||||
sa.Column('conversation_id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
['user_id'], ['user.id'], name='conversation_metadata_saas_user_fkey'
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['org_id'], ['org.id'], name='conversation_metadata_saas_org_fkey'
|
||||
),
|
||||
sa.PrimaryKeyConstraint('conversation_id'),
|
||||
)
|
||||
|
||||
# custom_secrets
|
||||
op.add_column(
|
||||
'custom_secrets',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'custom_secrets_org_fkey', 'custom_secrets', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# api_keys
|
||||
op.add_column(
|
||||
'api_keys', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
|
||||
)
|
||||
op.create_foreign_key('api_keys_org_fkey', 'api_keys', 'org', ['org_id'], ['id'])
|
||||
|
||||
# slack_conversation
|
||||
op.add_column(
|
||||
'slack_conversation',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'slack_conversation_org_fkey', 'slack_conversation', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# slack_users
|
||||
op.add_column(
|
||||
'slack_users', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'slack_users_org_fkey', 'slack_users', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# stripe_customers
|
||||
op.alter_column(
|
||||
'stripe_customers',
|
||||
'keycloak_user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
op.add_column(
|
||||
'stripe_customers',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'stripe_customers_org_fkey', 'stripe_customers', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop already_migrated column from user_settings table
|
||||
op.drop_column('user_settings', 'already_migrated')
|
||||
|
||||
# Drop foreign keys and columns added to existing tables
|
||||
op.drop_constraint(
|
||||
'stripe_customers_org_fkey', 'stripe_customers', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('stripe_customers', 'org_id')
|
||||
op.alter_column(
|
||||
'stripe_customers',
|
||||
'keycloak_user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
op.drop_constraint('slack_users_org_fkey', 'slack_users', type_='foreignkey')
|
||||
op.drop_column('slack_users', 'org_id')
|
||||
|
||||
op.drop_constraint(
|
||||
'slack_conversation_org_fkey', 'slack_conversation', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('slack_conversation', 'org_id')
|
||||
|
||||
op.drop_constraint('api_keys_org_fkey', 'api_keys', type_='foreignkey')
|
||||
op.drop_column('api_keys', 'org_id')
|
||||
|
||||
op.drop_constraint('custom_secrets_org_fkey', 'custom_secrets', type_='foreignkey')
|
||||
op.drop_column('custom_secrets', 'org_id')
|
||||
|
||||
# Drop conversation_metadata_saas table
|
||||
op.drop_table('conversation_metadata_saas')
|
||||
|
||||
op.drop_constraint(
|
||||
'billing_sessions_org_fkey', 'billing_sessions', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('billing_sessions', 'org_id')
|
||||
|
||||
# Drop tables in reverse order due to foreign key constraints
|
||||
op.drop_table('org_member')
|
||||
op.drop_table('user')
|
||||
op.drop_table('org')
|
||||
op.drop_table('role')
|
||||
Generated
+6001
-5599
File diff suppressed because one or more lines are too long
@@ -4,6 +4,10 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Ensure SAAS configuration is used
|
||||
if not os.getenv('OPENHANDS_CONFIG_CLS'):
|
||||
os.environ['OPENHANDS_CONFIG_CLS'] = 'server.config.SaaSServerConfig'
|
||||
|
||||
import socketio # noqa: E402
|
||||
from fastapi import Request, status # noqa: E402
|
||||
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
|
||||
@@ -34,6 +38,7 @@ from server.routes.integration.linear import linear_integration_router # noqa:
|
||||
from server.routes.integration.slack import slack_router # noqa: E402
|
||||
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
|
||||
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
||||
from server.routes.orgs import org_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||
@@ -86,6 +91,7 @@ if GITLAB_APP_CLIENT_ID:
|
||||
base_app.include_router(gitlab_integration_router)
|
||||
|
||||
base_app.include_router(api_keys_router) # Add routes for API key management
|
||||
base_app.include_router(org_router) # Add routes for organization management
|
||||
add_github_proxy_routes(base_app)
|
||||
add_debugging_routes(
|
||||
base_app
|
||||
|
||||
@@ -39,6 +39,8 @@ ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in (
|
||||
'on',
|
||||
)
|
||||
|
||||
DUPLICATE_EMAIL_CHECK = os.getenv('DUPLICATE_EMAIL_CHECK', 'true') in ('1', 'true')
|
||||
|
||||
# reCAPTCHA Enterprise
|
||||
RECAPTCHA_PROJECT_ID = os.getenv('RECAPTCHA_PROJECT_ID', '').strip()
|
||||
RECAPTCHA_SITE_KEY = os.getenv('RECAPTCHA_SITE_KEY', '').strip()
|
||||
|
||||
@@ -77,6 +77,15 @@ class SaasUserAuth(UserAuth):
|
||||
self.access_token = SecretStr(tokens['access_token'])
|
||||
self.refresh_token = SecretStr(tokens['refresh_token'])
|
||||
self.refreshed = True
|
||||
if not self.email or not self.email_verified or not self.user_id:
|
||||
# We don't need to verify the signature here because we just refreshed
|
||||
# this token from the IDP via token_manager.refresh()
|
||||
access_token_payload = jwt.decode(
|
||||
tokens['access_token'], options={'verify_signature': False}
|
||||
)
|
||||
self.user_id = access_token_payload['sub']
|
||||
self.email = access_token_payload['email']
|
||||
self.email_verified = access_token_payload['email_verified']
|
||||
|
||||
def _is_token_expired(self, token: SecretStr):
|
||||
logger.debug('saas_user_auth_is_token_expired')
|
||||
@@ -103,7 +112,6 @@ class SaasUserAuth(UserAuth):
|
||||
return settings
|
||||
settings_store = await self.get_user_settings_store()
|
||||
settings = await settings_store.load()
|
||||
# If load() returned None, should settings be created?
|
||||
if settings:
|
||||
settings.email = self.email
|
||||
settings.email_verified = self.email_verified
|
||||
@@ -274,11 +282,13 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
|
||||
if not user_id:
|
||||
return None
|
||||
offline_token = await token_manager.load_offline_token(user_id)
|
||||
return SaasUserAuth(
|
||||
saas_user_auth = SaasUserAuth(
|
||||
user_id=user_id,
|
||||
refresh_token=SecretStr(offline_token),
|
||||
auth_type=AuthType.BEARER,
|
||||
)
|
||||
await saas_user_auth.refresh()
|
||||
return saas_user_auth
|
||||
except Exception as exc:
|
||||
raise BearerTokenError from exc
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from keycloak.exceptions import (
|
||||
from server.auth.constants import (
|
||||
BITBUCKET_APP_CLIENT_ID,
|
||||
BITBUCKET_APP_CLIENT_SECRET,
|
||||
DUPLICATE_EMAIL_CHECK,
|
||||
GITHUB_APP_CLIENT_ID,
|
||||
GITHUB_APP_CLIENT_SECRET,
|
||||
GITLAB_APP_CLIENT_ID,
|
||||
@@ -646,6 +647,10 @@ class TokenManager:
|
||||
if not email:
|
||||
return False
|
||||
|
||||
# We have the option to skip the duplicate email check in test environments
|
||||
if not DUPLICATE_EMAIL_CHECK:
|
||||
return False
|
||||
|
||||
base_email = extract_base_email(email)
|
||||
if not base_email:
|
||||
logger.warning(f'Could not extract base email from: {email}')
|
||||
|
||||
@@ -8,7 +8,7 @@ import socketio
|
||||
from server.logger import logger
|
||||
from server.utils.conversation_callback_utils import invoke_conversation_callbacks
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
@@ -524,16 +524,18 @@ class ClusteredConversationManager(StandaloneConversationManager):
|
||||
)
|
||||
# Look up the user_id from the database
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
user_id = (
|
||||
conversation_metadata.user_id if conversation_metadata else None
|
||||
str(conversation_metadata_saas.user_id)
|
||||
if conversation_metadata_saas
|
||||
else None
|
||||
)
|
||||
# Handle the stopped conversation asynchronously
|
||||
asyncio.create_task(
|
||||
|
||||
@@ -19,8 +19,8 @@ IS_LOCAL_ENV = bool(HOST == 'localhost')
|
||||
DEFAULT_BILLING_MARGIN = float(os.environ.get('DEFAULT_BILLING_MARGIN', '1.0'))
|
||||
|
||||
# Map of user settings versions to their corresponding default LLM models
|
||||
# This ensures that CURRENT_USER_SETTINGS_VERSION and LITELLM_DEFAULT_MODEL stay in sync
|
||||
USER_SETTINGS_VERSION_TO_MODEL = {
|
||||
# This ensures that PERSONAL_WORKSPACE_VERSION_TO_MODEL and LITELLM_DEFAULT_MODEL stay in sync
|
||||
PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
|
||||
1: 'claude-3-5-sonnet-20241022',
|
||||
2: 'claude-3-7-sonnet-20250219',
|
||||
3: 'claude-sonnet-4-20250514',
|
||||
@@ -31,7 +31,8 @@ USER_SETTINGS_VERSION_TO_MODEL = {
|
||||
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
|
||||
|
||||
# Current user settings version - this should be the latest key in USER_SETTINGS_VERSION_TO_MODEL
|
||||
CURRENT_USER_SETTINGS_VERSION = max(USER_SETTINGS_VERSION_TO_MODEL.keys())
|
||||
ORG_SETTINGS_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
|
||||
PERSONAL_WORKSPACE_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
|
||||
|
||||
LITE_LLM_API_URL = os.environ.get(
|
||||
'LITE_LLM_API_URL', 'https://llm-proxy.app.all-hands.dev'
|
||||
@@ -55,7 +56,6 @@ SUBSCRIPTION_PRICE_DATA = {
|
||||
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
|
||||
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
|
||||
STRIPE_WEBHOOK_SECRET = os.environ.get('STRIPE_WEBHOOK_SECRET', None)
|
||||
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')
|
||||
|
||||
SLACK_CLIENT_ID = os.environ.get('SLACK_CLIENT_ID', None)
|
||||
@@ -91,5 +91,5 @@ def get_default_litellm_model():
|
||||
"""Construct proxy for litellm model based on user settings if not set explicitly."""
|
||||
if LITELLM_DEFAULT_MODEL:
|
||||
return LITELLM_DEFAULT_MODEL
|
||||
model = USER_SETTINGS_VERSION_TO_MODEL[CURRENT_USER_SETTINGS_VERSION]
|
||||
model = PERSONAL_WORKSPACE_VERSION_TO_MODEL[PERSONAL_WORKSPACE_VERSION]
|
||||
return build_litellm_proxy_model_path(model)
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Email domain validation utilities for enterprise endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_auth, get_user_id
|
||||
|
||||
|
||||
async def get_admin_user_id(
|
||||
request: Request, user_id: str | None = Depends(get_user_id)
|
||||
) -> str:
|
||||
"""
|
||||
Dependency that validates user has @openhands.dev email domain.
|
||||
|
||||
This dependency can be used in place of get_user_id for endpoints that
|
||||
should only be accessible to admin users. Currently, this is implemented
|
||||
by checking for @openhands.dev email domain.
|
||||
|
||||
TODO: In the future, this should be replaced with an explicit is_admin flag
|
||||
in user/org settings instead of relying on email domain validation.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
user_id: User ID from get_user_id dependency
|
||||
|
||||
Returns:
|
||||
str: User ID if email domain is valid
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if email domain is not @openhands.dev
|
||||
HTTPException: 401 if user is not authenticated
|
||||
|
||||
Example:
|
||||
@router.post('/endpoint')
|
||||
async def create_resource(
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
):
|
||||
# Only admin users can access this endpoint
|
||||
pass
|
||||
"""
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='User not authenticated',
|
||||
)
|
||||
|
||||
user_auth = await get_user_auth(request)
|
||||
user_email = await user_auth.get_user_email()
|
||||
|
||||
if not user_email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='User email not available',
|
||||
)
|
||||
|
||||
if not user_email.endswith('@openhands.dev'):
|
||||
logger.warning(
|
||||
'Access denied - invalid email domain',
|
||||
extra={'user_id': user_id, 'email_domain': user_email.split('@')[-1]},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='Access restricted to @openhands.dev users',
|
||||
)
|
||||
|
||||
return user_id
|
||||
@@ -44,11 +44,13 @@ class MyProcessor(MaintenanceTaskProcessor):
|
||||
### UserVersionUpgradeProcessor
|
||||
|
||||
Located in `user_version_upgrade_processor.py`, this processor:
|
||||
|
||||
- Handles up to 100 user IDs per task
|
||||
- Upgrades users with `user_version < CURRENT_USER_SETTINGS_VERSION`
|
||||
- Upgrades users with `user_version < ORG_SETTINGS_VERSION`
|
||||
- Uses `SaasSettingsStore.create_default_settings()` for upgrades
|
||||
|
||||
**Usage:**
|
||||
|
||||
```python
|
||||
from server.maintenance_task_processor.user_version_upgrade_processor import UserVersionUpgradeProcessor
|
||||
|
||||
@@ -144,22 +146,26 @@ task = create_maintenance_task(
|
||||
## Best Practices
|
||||
|
||||
### Processor Design
|
||||
|
||||
- Keep tasks short-running (under 1 minute)
|
||||
- Handle errors gracefully and return meaningful error information
|
||||
- Use batch processing for large datasets
|
||||
- Include progress information in the return dict
|
||||
|
||||
### Error Handling
|
||||
|
||||
- Always wrap your processor logic in try-catch blocks
|
||||
- Return structured error information
|
||||
- Log important events for debugging
|
||||
|
||||
### Performance
|
||||
|
||||
- Limit batch sizes to avoid long-running tasks
|
||||
- Use database sessions efficiently
|
||||
- Consider memory usage for large datasets
|
||||
|
||||
### Testing
|
||||
|
||||
- Create unit tests for your processors
|
||||
- Test error conditions
|
||||
- Verify the processor serialization/deserialization works correctly
|
||||
@@ -167,6 +173,7 @@ task = create_maintenance_task(
|
||||
## Database Patterns
|
||||
|
||||
The maintenance task system follows the repository's established patterns:
|
||||
|
||||
- Uses `session_maker()` for database operations
|
||||
- Wraps sync database operations in `call_sync_from_async` for async routes
|
||||
- Follows proper SQLAlchemy query patterns
|
||||
@@ -174,15 +181,18 @@ The maintenance task system follows the repository's established patterns:
|
||||
## Integration with Existing Systems
|
||||
|
||||
### User Management
|
||||
|
||||
- Integrates with the existing `UserSettings` model
|
||||
- Uses the current user versioning system (`CURRENT_USER_SETTINGS_VERSION`)
|
||||
- Uses the current user versioning system (`ORG_SETTINGS_VERSION`)
|
||||
- Maintains compatibility with existing user management workflows
|
||||
|
||||
### Authentication
|
||||
|
||||
- Admin endpoints use the existing SaaS authentication system
|
||||
- Requires users to have `admin = True` in their UserSettings
|
||||
|
||||
### Monitoring
|
||||
|
||||
- Tasks are logged with structured information
|
||||
- Status updates are tracked in the database
|
||||
- Error information is preserved for debugging
|
||||
@@ -206,6 +216,7 @@ The maintenance task system follows the repository's established patterns:
|
||||
## Future Enhancements
|
||||
|
||||
Potential improvements that could be added:
|
||||
|
||||
- Task dependencies and scheduling
|
||||
- Retry mechanisms for failed tasks
|
||||
- Real-time progress updates
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
from server.logger import logger
|
||||
from storage.database import session_maker
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskProcessor
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
|
||||
|
||||
class UserVersionUpgradeProcessor(MaintenanceTaskProcessor):
|
||||
"""
|
||||
Processor for upgrading user settings to the current version.
|
||||
|
||||
This processor takes a list of user IDs and upgrades any users
|
||||
whose user_version is less than CURRENT_USER_SETTINGS_VERSION.
|
||||
"""
|
||||
|
||||
user_ids: List[str]
|
||||
|
||||
async def __call__(self, task: MaintenanceTask) -> dict:
|
||||
"""
|
||||
Process user version upgrades for the specified user IDs.
|
||||
|
||||
Args:
|
||||
task: The maintenance task being processed
|
||||
|
||||
Returns:
|
||||
dict: Results containing successful and failed user IDs
|
||||
"""
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:start',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_count': len(self.user_ids),
|
||||
'current_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
if len(self.user_ids) > 100:
|
||||
raise ValueError(
|
||||
f'Too many user IDs: {len(self.user_ids)}. Maximum is 100.'
|
||||
)
|
||||
|
||||
config = load_openhands_config()
|
||||
|
||||
# Track results
|
||||
successful_upgrades = []
|
||||
failed_upgrades = []
|
||||
users_already_current = []
|
||||
|
||||
# Find users that need upgrading
|
||||
with session_maker() as session:
|
||||
users_to_upgrade = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id.in_(self.user_ids),
|
||||
UserSettings.user_version < CURRENT_USER_SETTINGS_VERSION,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Track users that are already current
|
||||
users_needing_upgrade_ids = {u.keycloak_user_id for u in users_to_upgrade}
|
||||
users_already_current = [
|
||||
uid for uid in self.user_ids if uid not in users_needing_upgrade_ids
|
||||
]
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:found_users',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'users_to_upgrade': len(users_to_upgrade),
|
||||
'users_already_current': len(users_already_current),
|
||||
'total_requested': len(self.user_ids),
|
||||
},
|
||||
)
|
||||
|
||||
# Process each user that needs upgrading
|
||||
for user_settings in users_to_upgrade:
|
||||
user_id = user_settings.keycloak_user_id
|
||||
old_version = user_settings.user_version
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:upgrading_user',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
# Create SaasSettingsStore instance and upgrade
|
||||
settings_store = await SaasSettingsStore.get_instance(config, user_id)
|
||||
await settings_store.create_default_settings(user_settings)
|
||||
|
||||
successful_upgrades.append(
|
||||
{
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:user_upgraded',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_upgrades.append(
|
||||
{'user_id': user_id, 'old_version': old_version, 'error': str(e)}
|
||||
)
|
||||
|
||||
logger.error(
|
||||
'user_version_upgrade_processor:user_upgrade_failed',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
|
||||
# Create result summary
|
||||
result = {
|
||||
'total_users': len(self.user_ids),
|
||||
'users_already_current': users_already_current,
|
||||
'successful_upgrades': successful_upgrades,
|
||||
'failed_upgrades': failed_upgrades,
|
||||
'summary': (
|
||||
f'Processed {len(self.user_ids)} users: '
|
||||
f'{len(successful_upgrades)} upgraded, '
|
||||
f'{len(users_already_current)} already current, '
|
||||
f'{len(failed_upgrades)} errors'
|
||||
),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:completed',
|
||||
extra={'task_id': task.id, 'result': result},
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
|
||||
@@ -36,6 +34,7 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
|
||||
Returns:
|
||||
A tuple containing the default SSE server configuration and a list of MCP stdio server configurations
|
||||
"""
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
if user_id:
|
||||
|
||||
@@ -162,6 +162,8 @@ class SetAuthCookieMiddleware:
|
||||
'/api/email/resend',
|
||||
'/oauth/device/authorize',
|
||||
'/oauth/device/token',
|
||||
'/api/v1/web-client/config',
|
||||
'/api/v1/webhooks/secrets',
|
||||
)
|
||||
if path in ignore_paths:
|
||||
return False
|
||||
|
||||
@@ -1,113 +1,95 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from server.config import get_config
|
||||
from server.constants import (
|
||||
BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
)
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
|
||||
# Helper functions for BYOR API key management
|
||||
async def get_byor_key_from_db(user_id: str) -> str | None:
|
||||
"""Get the BYOR key from the database for a user."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
user_db_settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
if user_db_settings and user_db_settings.llm_api_key_for_byor:
|
||||
return user_db_settings.llm_api_key_for_byor
|
||||
return None
|
||||
def _get_byor_key():
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
if current_org_member.llm_api_key_for_byor:
|
||||
return current_org_member.llm_api_key_for_byor.get_secret_value()
|
||||
return None
|
||||
|
||||
return await call_sync_from_async(_get_byor_key)
|
||||
|
||||
|
||||
async def store_byor_key_in_db(user_id: str, key: str) -> None:
|
||||
"""Store the BYOR key in the database for a user."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
def _update_user_settings():
|
||||
with session_maker() as session:
|
||||
user_db_settings = settings_store.get_user_settings_by_keycloak_id(
|
||||
user_id, session
|
||||
)
|
||||
if user_db_settings:
|
||||
user_db_settings.llm_api_key_for_byor = key
|
||||
session.commit()
|
||||
logger.info(
|
||||
'Successfully stored BYOR key in user settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'User settings not found when trying to store BYOR key',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
current_org_member.llm_api_key_for_byor = key
|
||||
OrgMemberStore.update_org_member(current_org_member)
|
||||
|
||||
await call_sync_from_async(_update_user_settings)
|
||||
|
||||
|
||||
async def generate_byor_key(user_id: str) -> str | None:
|
||||
"""Generate a new BYOR key for a user."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json={
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
return None
|
||||
current_org_id = str(user.current_org_id)
|
||||
key = await LiteLlmManager.generate_key(
|
||||
user_id,
|
||||
current_org_id,
|
||||
f'BYOR Key - user {user_id}, org {current_org_id}',
|
||||
{'type': 'byor'},
|
||||
)
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'Successfully generated new BYOR key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'metadata': {'type': 'byor'},
|
||||
'key_alias': f'BYOR Key - user {user_id}',
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': key[:10] + '...' if key and len(key) > 10 else key,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json.get('key')
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'Successfully generated new BYOR key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': key[:10] + '...'
|
||||
if key and len(key) > 10
|
||||
else key,
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate BYOR LLM API key - no key in response',
|
||||
extra={'user_id': user_id, 'response_json': response_json},
|
||||
)
|
||||
return None
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate BYOR LLM API key - no key in response',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error generating BYOR key',
|
||||
@@ -116,96 +98,25 @@ async def generate_byor_key(user_id: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
async def verify_byor_key_in_litellm(byor_key: str, user_id: str) -> bool:
|
||||
"""Verify that a BYOR key is valid in LiteLLM by making a lightweight API call.
|
||||
|
||||
Args:
|
||||
byor_key: The BYOR key to verify
|
||||
user_id: The user ID for logging purposes
|
||||
|
||||
Returns:
|
||||
True if the key is verified as valid, False if verification fails or key is invalid.
|
||||
Returns False on network errors/timeouts to ensure we don't return potentially invalid keys.
|
||||
"""
|
||||
if not (LITE_LLM_API_URL and byor_key):
|
||||
return False
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
timeout=BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||
) as client:
|
||||
# Make a lightweight request to verify the key
|
||||
# Using /v1/models endpoint as it's lightweight and requires authentication
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/v1/models',
|
||||
headers={
|
||||
'Authorization': f'Bearer {byor_key}',
|
||||
},
|
||||
)
|
||||
|
||||
# Only 200 status code indicates valid key
|
||||
if response.status_code == 200:
|
||||
logger.debug(
|
||||
'BYOR key verification successful',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
|
||||
# All other status codes (401, 403, 500, etc.) are treated as invalid
|
||||
# This includes authentication errors and server errors
|
||||
logger.warning(
|
||||
'BYOR key verification failed - treating as invalid',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': response.status_code,
|
||||
'key_prefix': byor_key[:10] + '...'
|
||||
if len(byor_key) > 10
|
||||
else byor_key,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
except (httpx.TimeoutException, Exception) as e:
|
||||
# Any exception (timeout, network error, etc.) means we can't verify
|
||||
# Return False to trigger regeneration rather than returning potentially invalid key
|
||||
logger.warning(
|
||||
'BYOR key verification error - treating as invalid to ensure key validity',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
'error_type': type(e).__name__,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
|
||||
"""Delete the BYOR key from LiteLLM using the key directly."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
return False
|
||||
"""Delete the BYOR key from LiteLLM using the key directly.
|
||||
|
||||
Also attempts to delete by key alias if the key is not found,
|
||||
to clean up orphaned aliases that could block key regeneration.
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
# Delete the key directly using the key value
|
||||
delete_url = f'{LITE_LLM_API_URL}/key/delete'
|
||||
delete_payload = {'keys': [byor_key]}
|
||||
# Get user to construct the key alias
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
key_alias = None
|
||||
if user and user.current_org_id:
|
||||
key_alias = f'BYOR Key - user {user_id}, org {user.current_org_id}'
|
||||
|
||||
delete_response = await client.post(delete_url, json=delete_payload)
|
||||
delete_response.raise_for_status()
|
||||
logger.info(
|
||||
'Successfully deleted BYOR key from LiteLLM',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
await LiteLlmManager.delete_key(byor_key, key_alias=key_alias)
|
||||
logger.info(
|
||||
'Successfully deleted BYOR key from LiteLLM',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error deleting BYOR key from LiteLLM',
|
||||
@@ -357,7 +268,7 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
byor_key = await get_byor_key_from_db(user_id)
|
||||
if byor_key:
|
||||
# Validate that the key is actually registered in LiteLLM
|
||||
is_valid = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
is_valid = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
if is_valid:
|
||||
return {'key': byor_key}
|
||||
else:
|
||||
@@ -412,15 +323,6 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
logger.info('Starting BYOR LLM API key refresh', extra={'user_id': user_id})
|
||||
|
||||
try:
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='LiteLLM API configuration not found',
|
||||
)
|
||||
|
||||
# Get the existing BYOR key from the database
|
||||
existing_byor_key = await get_byor_key_from_db(user_id)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Literal, Optional
|
||||
@@ -22,12 +23,12 @@ from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
||||
from server.auth.recaptcha_service import recaptcha_service
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config, sign_token
|
||||
from server.config import sign_token
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from server.routes.event_webhook import _get_session_api_key, _get_user_id
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
@@ -87,7 +88,8 @@ def get_cookie_domain(request: Request) -> str | None:
|
||||
# for now just use the full hostname except for staging stacks.
|
||||
return (
|
||||
None
|
||||
if (request.url.hostname or '').endswith('staging.all-hand.dev')
|
||||
if not request.url.hostname
|
||||
or request.url.hostname.endswith('staging.all-hands.dev')
|
||||
else request.url.hostname
|
||||
)
|
||||
|
||||
@@ -174,6 +176,20 @@ async def keycloak_callback(
|
||||
|
||||
email = user_info.get('email')
|
||||
user_id = user_info['sub']
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
user = await UserStore.create_user(user_id, user_info)
|
||||
|
||||
if not user:
|
||||
logger.error(f'Failed to authenticate user {user_info["preferred_username"]}')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': f'Failed to authenticate user {user_info["preferred_username"]}'
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f'Logging in user {str(user.id)} in org {user.current_org_id}')
|
||||
|
||||
# reCAPTCHA verification with Account Defender
|
||||
if RECAPTCHA_SITE_KEY:
|
||||
@@ -360,15 +376,7 @@ async def keycloak_callback(
|
||||
f'&state={state}'
|
||||
)
|
||||
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
user_settings = settings_store.get_user_settings_by_keycloak_id(user_id)
|
||||
has_accepted_tos = (
|
||||
user_settings is not None and user_settings.accepted_tos is not None
|
||||
)
|
||||
|
||||
has_accepted_tos = user.accepted_tos is not None
|
||||
# If the user hasn't accepted the TOS, redirect to the TOS page
|
||||
if not has_accepted_tos:
|
||||
encoded_redirect_url = quote(redirect_url, safe='')
|
||||
@@ -486,28 +494,20 @@ async def accept_tos(request: Request):
|
||||
redirect_url = body.get('redirect_url', str(request.base_url))
|
||||
|
||||
# Update user settings with TOS acceptance
|
||||
accepted_tos: datetime = datetime.now(timezone.utc)
|
||||
with session_maker() as session:
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user_settings:
|
||||
user_settings.accepted_tos = datetime.now(timezone.utc)
|
||||
session.merge(user_settings)
|
||||
else:
|
||||
# Create user settings if they don't exist
|
||||
user_settings = UserSettings(
|
||||
keycloak_user_id=user_id,
|
||||
accepted_tos=datetime.now(timezone.utc),
|
||||
user_version=0, # This will trigger a migration to the latest version on next load
|
||||
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
|
||||
if not user:
|
||||
session.rollback()
|
||||
logger.error('User for {user_id} not found.')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'User does not exist'},
|
||||
)
|
||||
session.add(user_settings)
|
||||
|
||||
user.accepted_tos = accepted_tos
|
||||
session.commit()
|
||||
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
|
||||
response = JSONResponse(
|
||||
status_code=status.HTTP_200_OK, content={'redirect_url': redirect_url}
|
||||
|
||||
@@ -4,30 +4,22 @@ from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
import stripe
|
||||
from dateutil.relativedelta import relativedelta # type: ignore
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.config import get_config
|
||||
from server.constants import (
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
STRIPE_API_KEY,
|
||||
STRIPE_WEBHOOK_SECRET,
|
||||
SUBSCRIPTION_PRICE_DATA,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
billing_router = APIRouter(prefix='/api/billing')
|
||||
@@ -106,31 +98,20 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
|
||||
return max(max_budget - spend, 0.0)
|
||||
|
||||
|
||||
# Endpoint to retrieve user's current credit balance
|
||||
# Endpoint to retrieve the current organization's credit balance
|
||||
@billing_router.get('/credits')
|
||||
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
|
||||
if not stripe_service.STRIPE_API_KEY:
|
||||
return GetCreditsResponse()
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=15.0
|
||||
) as client:
|
||||
user_json = await _get_litellm_user(client, user_id)
|
||||
credits = calculate_credits(user_json['user_info'])
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
f'litellm_get_user_failed: {type(e).__name__}: {e}',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': e.response.status_code,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve credit balance from billing service',
|
||||
)
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
user_id, str(user.current_org_id)
|
||||
)
|
||||
# Update to use calculate_credits
|
||||
spend = user_team_info.get('spend', 0)
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
|
||||
credits = max(max_budget - spend, 0)
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
|
||||
|
||||
# Endpoint to retrieve user's current subscription access
|
||||
@@ -165,79 +146,7 @@ async def get_subscription_access(
|
||||
async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
|
||||
if not user_id:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return await stripe_service.has_payment_method(user_id)
|
||||
|
||||
|
||||
# Endpoint to cancel user's subscription
|
||||
@billing_router.post('/cancel-subscription')
|
||||
async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONResponse:
|
||||
"""Cancel user's active subscription at the end of the current billing period."""
|
||||
if not user_id:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
with session_maker() as session:
|
||||
# Find the user's active subscription
|
||||
now = datetime.now(UTC)
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not already cancelled
|
||||
.first()
|
||||
)
|
||||
|
||||
if not subscription_access:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='No active subscription found',
|
||||
)
|
||||
|
||||
if not subscription_access.stripe_subscription_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot cancel subscription: missing Stripe subscription ID',
|
||||
)
|
||||
|
||||
try:
|
||||
# Cancel the subscription in Stripe at period end
|
||||
await stripe.Subscription.modify_async(
|
||||
subscription_access.stripe_subscription_id, cancel_at_period_end=True
|
||||
)
|
||||
|
||||
# Update local database
|
||||
subscription_access.cancelled_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'subscription_cancelled',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'stripe_subscription_id': subscription_access.stripe_subscription_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
'end_at': subscription_access.end_at,
|
||||
},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
{'status': 'success', 'message': 'Subscription cancelled successfully'}
|
||||
)
|
||||
|
||||
except stripe.StripeError as e:
|
||||
logger.error(
|
||||
'stripe_cancellation_failed',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'stripe_subscription_id': subscription_access.stripe_subscription_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f'Failed to cancel subscription: {str(e)}',
|
||||
)
|
||||
return await stripe_service.has_payment_method_by_user_id(user_id)
|
||||
|
||||
|
||||
# Endpoint to create a new setup intent in stripe
|
||||
@@ -246,16 +155,15 @@ async def create_customer_setup_session(
|
||||
request: Request, user_id: str = Depends(get_user_id)
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
customer=customer_info['customer_id'],
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{request.base_url}?free_credits=success',
|
||||
cancel_url=f'{request.base_url}',
|
||||
)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
|
||||
# Endpoint to create a new Stripe checkout session for credit purchase
|
||||
@@ -267,9 +175,9 @@ async def create_checkout_session(
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
customer=customer_info['customer_id'],
|
||||
line_items=[
|
||||
{
|
||||
'price_data': {
|
||||
@@ -282,7 +190,7 @@ async def create_checkout_session(
|
||||
'tax_behavior': 'exclusive',
|
||||
},
|
||||
'quantity': 1,
|
||||
}
|
||||
},
|
||||
],
|
||||
mode='payment',
|
||||
payment_method_types=['card'],
|
||||
@@ -295,8 +203,9 @@ async def create_checkout_session(
|
||||
logger.info(
|
||||
'created_stripe_checkout_session',
|
||||
extra={
|
||||
'stripe_customer_id': customer_id,
|
||||
'stripe_customer_id': customer_info['customer_id'],
|
||||
'user_id': user_id,
|
||||
'org_id': customer_info['org_id'],
|
||||
'amount': body.amount,
|
||||
'checkout_session_id': checkout_session.id,
|
||||
},
|
||||
@@ -305,105 +214,14 @@ async def create_checkout_session(
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
org_id=customer_info['org_id'],
|
||||
price=body.amount,
|
||||
price_code='NA',
|
||||
billing_session_type=BillingSessionType.DIRECT_PAYMENT.value,
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@billing_router.post('/subscription-checkout-session')
|
||||
async def create_subscription_checkout_session(
|
||||
request: Request,
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
# Prevent duplicate subscriptions for the same user
|
||||
with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
existing_active_subscription = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not cancelled
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_active_subscription:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot create subscription: User already has an active subscription that has not been cancelled',
|
||||
)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
subscription_price_data = SUBSCRIPTION_PRICE_DATA[billing_session_type.value]
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
line_items=[
|
||||
{
|
||||
'price_data': subscription_price_data,
|
||||
'quantity': 1,
|
||||
}
|
||||
],
|
||||
mode='subscription',
|
||||
payment_method_types=['card'],
|
||||
saved_payment_method_options={
|
||||
'payment_method_save': 'enabled',
|
||||
},
|
||||
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
subscription_data={
|
||||
'metadata': {
|
||||
'user_id': user_id,
|
||||
'billing_session_type': billing_session_type.value,
|
||||
}
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
'created_stripe_subscription_checkout_session',
|
||||
extra={
|
||||
'stripe_customer_id': customer_id,
|
||||
'user_id': user_id,
|
||||
'checkout_session_id': checkout_session.id,
|
||||
'billing_session_type': billing_session_type.value,
|
||||
},
|
||||
)
|
||||
with session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
price=subscription_price_data['unit_amount'],
|
||||
price_code='NA',
|
||||
billing_session_type=billing_session_type.value,
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(
|
||||
redirect_url=typing.cast(str, checkout_session.url)
|
||||
)
|
||||
|
||||
|
||||
@billing_router.get('/create-subscription-checkout-session')
|
||||
async def create_subscription_checkout_session_via_get(
|
||||
request: Request,
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> RedirectResponse:
|
||||
"""Create a subscription checkout session using a GET request (For easier copy / paste to URL bar)."""
|
||||
validate_saas_environment(request)
|
||||
|
||||
response = await create_subscription_checkout_session(
|
||||
request, billing_session_type, user_id
|
||||
)
|
||||
return RedirectResponse(response.redirect_url)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
|
||||
# Callback endpoint for successful Stripe payments - updates user credits and billing session status
|
||||
@@ -425,15 +243,6 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
# Any non direct payment (Subscription) is processed in the invoice_payment.paid by the webhook
|
||||
if (
|
||||
billing_session.billing_session_type
|
||||
!= BillingSessionType.DIRECT_PAYMENT.value
|
||||
):
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=success', status_code=302
|
||||
)
|
||||
|
||||
stripe_session = stripe.checkout.Session.retrieve(session_id)
|
||||
if stripe_session.status != 'complete':
|
||||
# Hopefully this never happens - we get a redirect from stripe where the payment is not yet complete
|
||||
@@ -447,31 +256,37 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
# Update max budget in litellm
|
||||
user_json = await _get_litellm_user(client, billing_session.user_id)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
new_max_budget = (
|
||||
(user_json.get('user_info') or {}).get('max_budget') or 0
|
||||
) + add_credits
|
||||
await _upsert_litellm_user(client, billing_session.user_id, new_max_budget)
|
||||
user = await UserStore.get_user_by_id_async(billing_session.user_id)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
)
|
||||
new_max_budget = max_budget + add_credits
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = amount_subtotal
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
logger.info(
|
||||
'stripe_checkout_success',
|
||||
extra={
|
||||
'amount_subtotal': stripe_session.amount_subtotal,
|
||||
'user_id': billing_session.user_id,
|
||||
'checkout_session_id': billing_session.id,
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
str(user.current_org_id), new_max_budget
|
||||
)
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = add_credits
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
logger.info(
|
||||
'stripe_checkout_success',
|
||||
extra={
|
||||
'amount_subtotal': stripe_session.amount_subtotal,
|
||||
'user_id': billing_session.user_id,
|
||||
'org_id': str(user.current_org_id),
|
||||
'checkout_session_id': billing_session.id,
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings/billing?checkout=success', status_code=302
|
||||
@@ -501,206 +316,6 @@ async def cancel_callback(session_id: str, request: Request):
|
||||
session.merge(billing_session)
|
||||
session.commit()
|
||||
|
||||
# Redirect credit purchases to billing screen, subscriptions to LLM settings
|
||||
if (
|
||||
billing_session.billing_session_type
|
||||
== BillingSessionType.DIRECT_PAYMENT.value
|
||||
):
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings/billing?checkout=cancel',
|
||||
status_code=302,
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
# If no billing session found, default to LLM settings (subscription flow)
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=cancel', status_code=302
|
||||
f'{request.base_url}settings/billing?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
|
||||
@billing_router.post('/stripe-webhook')
|
||||
async def stripe_webhook(request: Request) -> JSONResponse:
|
||||
"""Endpoint for stripe webhooks."""
|
||||
payload = await request.body()
|
||||
sig_header = request.headers.get('stripe-signature')
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
raise HTTPException(status_code=400, detail=f'Invalid payload: {e}')
|
||||
except stripe.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
raise HTTPException(status_code=400, detail=f'Invalid signature: {e}')
|
||||
|
||||
# Handle the event
|
||||
logger.info('stripe_webhook_event', extra={'event': event})
|
||||
event_type = event['type']
|
||||
if event_type == 'invoice.paid':
|
||||
invoice = event['data']['object']
|
||||
amount_paid = invoice.amount_paid
|
||||
metadata = invoice.parent.subscription_details.metadata # type: ignore
|
||||
billing_session_type = metadata.billing_session_type
|
||||
assert (
|
||||
amount_paid == SUBSCRIPTION_PRICE_DATA[billing_session_type]['unit_amount']
|
||||
)
|
||||
user_id = metadata.user_id
|
||||
|
||||
start_at = datetime.now(UTC)
|
||||
if billing_session_type == BillingSessionType.MONTHLY_SUBSCRIPTION.value:
|
||||
end_at = start_at + relativedelta(months=1)
|
||||
else:
|
||||
raise ValueError(f'unknown_billing_session_type:{billing_session_type}')
|
||||
|
||||
with session_maker() as session:
|
||||
subscription_access = SubscriptionAccess(
|
||||
status='ACTIVE',
|
||||
user_id=user_id,
|
||||
start_at=start_at,
|
||||
end_at=end_at,
|
||||
amount_paid=amount_paid,
|
||||
stripe_invoice_payment_id=invoice.payment_intent,
|
||||
stripe_subscription_id=invoice.subscription, # Store Stripe subscription ID
|
||||
)
|
||||
session.add(subscription_access)
|
||||
session.commit()
|
||||
elif event_type == 'customer.subscription.updated':
|
||||
subscription = event['data']['object']
|
||||
subscription_id = subscription['id']
|
||||
|
||||
# Handle subscription cancellation
|
||||
if subscription.get('cancel_at_period_end') is True:
|
||||
with session_maker() as session:
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(
|
||||
SubscriptionAccess.stripe_subscription_id == subscription_id
|
||||
)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access and not subscription_access.cancelled_at:
|
||||
subscription_access.cancelled_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'subscription_cancelled_via_webhook',
|
||||
extra={
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'user_id': subscription_access.user_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
},
|
||||
)
|
||||
elif event_type == 'customer.subscription.deleted':
|
||||
subscription = event['data']['object']
|
||||
subscription_id = subscription['id']
|
||||
|
||||
with session_maker() as session:
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.stripe_subscription_id == subscription_id)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access:
|
||||
subscription_access.status = 'DISABLED'
|
||||
subscription_access.updated_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
# Reset user settings to free tier defaults
|
||||
reset_user_to_free_tier_settings(subscription_access.user_id)
|
||||
|
||||
logger.info(
|
||||
'subscription_expired_reset_to_free_tier',
|
||||
extra={
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'user_id': subscription_access.user_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info('stripe_webhook_unhandled_event_type', extra={'type': event_type})
|
||||
|
||||
return JSONResponse({'status': 'success'})
|
||||
|
||||
|
||||
def reset_user_to_free_tier_settings(user_id: str) -> None:
|
||||
"""Reset user settings to free tier defaults when subscription ends."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
user_settings = settings_store.get_user_settings_by_keycloak_id(
|
||||
user_id, session
|
||||
)
|
||||
|
||||
if user_settings:
|
||||
user_settings.llm_model = get_default_litellm_model()
|
||||
user_settings.llm_api_key = None
|
||||
user_settings.llm_api_key_for_byor = None
|
||||
user_settings.llm_base_url = LITE_LLM_API_URL
|
||||
user_settings.max_budget_per_task = None
|
||||
user_settings.confirmation_mode = False
|
||||
user_settings.enable_solvability_analysis = False
|
||||
user_settings.security_analyzer = 'llm'
|
||||
user_settings.agent = 'CodeActAgent'
|
||||
user_settings.language = 'en'
|
||||
user_settings.enable_default_condenser = True
|
||||
user_settings.enable_sound_notifications = False
|
||||
user_settings.enable_proactive_conversation_starters = True
|
||||
user_settings.user_consents_to_analytics = False
|
||||
|
||||
session.merge(user_settings)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'user_settings_reset_to_free_tier',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'reset_timestamp': datetime.now(UTC).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _get_litellm_user(client: httpx.AsyncClient, user_id: str) -> dict:
|
||||
"""Get a user from litellm with the id matching that given.
|
||||
|
||||
If no such user exists, returns a dummy user in the format:
|
||||
`{'user_id': '<USER_ID>', 'user_info': {'spend': 0}, 'keys': [], 'teams': []}`
|
||||
"""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
async def _upsert_litellm_user(
|
||||
client: httpx.AsyncClient, user_id: str, max_budget: float
|
||||
):
|
||||
"""Insert / Update a user in litellm."""
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/update',
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
json={
|
||||
'user_id': user_id,
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -5,8 +5,8 @@ from threading import Thread
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from sqlalchemy import func, select
|
||||
from storage.database import a_session_maker, engine, session_maker
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.database import a_session_maker, get_engine, session_maker
|
||||
from storage.user import User
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.async_utils import wait_all
|
||||
@@ -47,6 +47,7 @@ def add_debugging_routes(api: FastAPI):
|
||||
- checked_out: Number of connections currently in use
|
||||
- overflow: Number of overflow connections created beyond pool_size
|
||||
"""
|
||||
engine = get_engine()
|
||||
return {
|
||||
'checked_in': engine.pool.checkedin(),
|
||||
'checked_out': engine.pool.checkedout(),
|
||||
@@ -127,8 +128,9 @@ def _db_check(delay: int):
|
||||
delay: Number of seconds to hold the database connection
|
||||
"""
|
||||
with session_maker() as session:
|
||||
num_users = session.query(UserSettings).count()
|
||||
num_users = session.query(User).count()
|
||||
time.sleep(delay)
|
||||
engine = get_engine()
|
||||
logger.info(
|
||||
'check',
|
||||
extra={
|
||||
@@ -155,7 +157,7 @@ async def _a_db_check(delay: int):
|
||||
delay: Number of seconds to hold the database connection
|
||||
"""
|
||||
async with a_session_maker() as a_session:
|
||||
stmt = select(func.count(UserSettings.id))
|
||||
stmt = select(func.count(User.id))
|
||||
num_users = await a_session.execute(stmt)
|
||||
await asyncio.sleep(delay)
|
||||
logger.info(f'a_num_users:{num_users.scalar_one()}')
|
||||
|
||||
@@ -21,7 +21,7 @@ from server.utils.conversation_callback_utils import (
|
||||
update_conversation_stats,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.server.shared import conversation_manager
|
||||
|
||||
@@ -226,12 +226,12 @@ def _parse_conversation_id_and_subpath(path: str) -> Tuple[str, str]:
|
||||
|
||||
def _get_user_id(conversation_id: str) -> str:
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
return conversation_metadata.user_id
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
|
||||
|
||||
async def _get_session_api_key(user_id: str, conversation_id: str) -> str | None:
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy.future import select
|
||||
from storage.database import session_maker
|
||||
from storage.feedback import ConversationFeedback
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.server.shared import file_store
|
||||
@@ -33,10 +33,10 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
|
||||
def _verify_conversation():
|
||||
with session_maker() as session:
|
||||
metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id == conversation_id,
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -5,15 +7,16 @@ import uuid
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
|
||||
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from integrations.jira.jira_manager import JiraManager
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.utils import HOST_URL
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from server.auth.constants import JIRA_CLIENT_ID, JIRA_CLIENT_SECRET
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import WEB_HOST
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
from storage.redis import create_redis_client
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -24,7 +27,7 @@ JIRA_WEBHOOKS_ENABLED = os.environ.get('JIRA_WEBHOOKS_ENABLED', '0') in (
|
||||
'1',
|
||||
'true',
|
||||
)
|
||||
JIRA_REDIRECT_URI = f'https://{WEB_HOST}/integration/jira/callback'
|
||||
JIRA_REDIRECT_URI = f'{HOST_URL}/integration/jira/callback'
|
||||
JIRA_SCOPES = 'read:me read:jira-user read:jira-work'
|
||||
JIRA_AUTH_URL = 'https://auth.atlassian.com/authorize'
|
||||
JIRA_TOKEN_URL = 'https://auth.atlassian.com/oauth/token'
|
||||
@@ -122,6 +125,63 @@ jira_manager = JiraManager(token_manager)
|
||||
redis_client = create_redis_client()
|
||||
|
||||
|
||||
async def verify_jira_signature(body: bytes, signature: str, payload: dict):
|
||||
"""
|
||||
Verify Jira webhook signature.
|
||||
|
||||
Args:
|
||||
body: Raw request body bytes
|
||||
signature: Signature from x-hub-signature header (format: "sha256=<hash>")
|
||||
payload: Parsed JSON payload from webhook
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if signature verification fails or workspace is invalid
|
||||
|
||||
Returns:
|
||||
None (raises exception on failure)
|
||||
"""
|
||||
|
||||
if not signature:
|
||||
raise HTTPException(
|
||||
status_code=403, detail='x-hub-signature header is missing!'
|
||||
)
|
||||
|
||||
workspace_name = jira_manager.get_workspace_name_from_payload(payload)
|
||||
if workspace_name is None:
|
||||
logger.warning('[Jira] No workspace name found in webhook payload')
|
||||
raise HTTPException(
|
||||
status_code=403, detail='Workspace name not found in payload'
|
||||
)
|
||||
|
||||
workspace: (
|
||||
JiraWorkspace | None
|
||||
) = await jira_manager.integration_store.get_workspace_by_name(workspace_name)
|
||||
|
||||
if workspace is None:
|
||||
logger.warning(f'[Jira] Could not identify workspace {workspace_name}')
|
||||
raise HTTPException(status_code=403, detail='Unidentified workspace')
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(
|
||||
'[Jira] Workspace is inactive',
|
||||
extra={
|
||||
'jira_workspace_id': workspace.id,
|
||||
'parsed_workspace_name': workspace.name,
|
||||
'status': workspace.status,
|
||||
},
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=403, detail='Workspace is inactive')
|
||||
|
||||
webhook_secret = token_manager.decrypt_text(workspace.webhook_secret)
|
||||
expected_signature = hmac.new(
|
||||
webhook_secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(expected_signature, signature):
|
||||
raise HTTPException(status_code=403, detail="Request signatures didn't match!")
|
||||
|
||||
|
||||
async def _handle_workspace_link_creation(
|
||||
user_id: str, jira_user_id: str, target_workspace: str
|
||||
):
|
||||
@@ -216,6 +276,7 @@ async def _validate_workspace_update_permissions(user_id: str, target_workspace:
|
||||
async def jira_events(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
x_hub_signature: str = Header(None),
|
||||
):
|
||||
"""Handle Jira webhook events."""
|
||||
# Check if Jira webhooks are enabled
|
||||
@@ -227,13 +288,15 @@ async def jira_events(
|
||||
)
|
||||
|
||||
try:
|
||||
signature_valid, signature, payload = await jira_manager.validate_request(
|
||||
request
|
||||
)
|
||||
parts = x_hub_signature.split('=', 1)
|
||||
if not (len(parts) == 2 and parts[1]):
|
||||
raise HTTPException(status_code=403, detail='Malformed x-hub-signature!')
|
||||
|
||||
if not signature_valid:
|
||||
logger.warning('[Jira] Invalid webhook signature')
|
||||
raise HTTPException(status_code=403, detail='Invalid webhook signature!')
|
||||
signature = parts[1]
|
||||
body = await request.body()
|
||||
payload = await request.json()
|
||||
|
||||
await verify_jira_signature(body, signature, payload)
|
||||
|
||||
# Check for duplicate requests using Redis
|
||||
key = f'jira:{signature}'
|
||||
|
||||
@@ -15,7 +15,6 @@ from integrations.slack.slack_manager import SlackManager
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
from server.auth.constants import (
|
||||
KEYCLOAK_CLIENT_ID,
|
||||
KEYCLOAK_REALM_NAME,
|
||||
@@ -35,6 +34,7 @@ from slack_sdk.web.async_client import AsyncWebClient
|
||||
from storage.database import session_maker
|
||||
from storage.slack_team_store import SlackTeamStore
|
||||
from storage.slack_user import SlackUser
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.shared import config, sio
|
||||
@@ -79,6 +79,14 @@ async def install_callback(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if not config.jwt_secret:
|
||||
logger.error('slack_install_callback_error JWT not configured.')
|
||||
return _html_response(
|
||||
title='Error',
|
||||
description=html.escape('JWT not configured'),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
try:
|
||||
client = AsyncWebClient() # no prepared token needed for this
|
||||
# Complete the installation by calling oauth.v2.access API method
|
||||
@@ -94,16 +102,17 @@ async def install_callback(
|
||||
|
||||
# Create a state variable for keycloak oauth
|
||||
payload = {}
|
||||
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
|
||||
if state:
|
||||
payload = jwt.decode(
|
||||
state, jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
)
|
||||
payload['slack_user_id'] = authed_user.get('id')
|
||||
payload['bot_access_token'] = bot_access_token
|
||||
payload['team_id'] = team_id
|
||||
|
||||
state = jwt.encode(payload, jwt_secret.get_secret_value(), algorithm='HS256')
|
||||
state = jwt.encode(
|
||||
payload, config.jwt_secret.get_secret_value(), algorithm='HS256'
|
||||
)
|
||||
|
||||
# Redirect into keycloak
|
||||
scope = quote('openid email profile offline_access')
|
||||
@@ -149,9 +158,16 @@ async def keycloak_callback(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
|
||||
if not config.jwt_secret:
|
||||
logger.error('problem_retrieving_keycloak_tokens JWT not configured.')
|
||||
return _html_response(
|
||||
title='Error',
|
||||
description=html.escape('JWT not configured'),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
payload: dict[str, str] = jwt.decode(
|
||||
state, jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
)
|
||||
slack_user_id = payload['slack_user_id']
|
||||
bot_access_token = payload['bot_access_token']
|
||||
@@ -180,6 +196,13 @@ async def keycloak_callback(
|
||||
|
||||
user_info = await token_manager.get_user_info(keycloak_access_token)
|
||||
keycloak_user_id = user_info['sub']
|
||||
user = await UserStore.get_user_by_id_async(keycloak_user_id)
|
||||
if not user:
|
||||
return _html_response(
|
||||
title='Failed to authenticate.',
|
||||
description=f'Please re-login into <a href="{HOST_URL}" style="color:#ecedee;text-decoration:underline;">OpenHands Cloud</a>. Then try <a href="https://docs.all-hands.dev/usage/cloud/slack-installation" style="color:#ecedee;text-decoration:underline;">installing the OpenHands Slack App</a> again',
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# These tokens are offline access tokens - store them!
|
||||
await token_manager.store_offline_token(keycloak_user_id, keycloak_refresh_token)
|
||||
@@ -211,6 +234,7 @@ async def keycloak_callback(
|
||||
slack_display_name = slack_user_info.data['user']['profile']['display_name']
|
||||
slack_user = SlackUser(
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
org_id=user.current_org_id,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_display_name=slack_display_name,
|
||||
)
|
||||
@@ -305,7 +329,7 @@ async def on_form_interaction(request: Request, background_tasks: BackgroundTask
|
||||
|
||||
body = await request.body()
|
||||
form = await request.form()
|
||||
payload = json.loads(form.get('payload')) # type: ignore[arg-type]
|
||||
payload = json.loads(form.get('payload'))
|
||||
|
||||
logger.info('slack_on_form_interaction', extra={'payload': payload})
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ DEVICE_CODE_EXPIRES_IN = 600 # 10 minutes
|
||||
DEVICE_TOKEN_POLL_INTERVAL = 5 # seconds
|
||||
|
||||
API_KEY_NAME = 'Device Link Access Key'
|
||||
KEY_EXPIRATION_TIME = timedelta(days=1) # Key expires in 24 hours
|
||||
KEY_EXPIRATION_TIME = timedelta(days=7) # Key expires in a week
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Models
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class OrgCreationError(Exception):
|
||||
"""Base exception for organization creation errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OrgNameExistsError(OrgCreationError):
|
||||
"""Raised when an organization name already exists."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
super().__init__(f'Organization with name "{name}" already exists')
|
||||
|
||||
|
||||
class LiteLLMIntegrationError(OrgCreationError):
|
||||
"""Raised when LiteLLM integration fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OrgDatabaseError(OrgCreationError):
|
||||
"""Raised when database operations fail."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OrgCreate(BaseModel):
|
||||
"""Request model for creating a new organization."""
|
||||
|
||||
# Required fields
|
||||
name: str = Field(min_length=1, max_length=255, strip_whitespace=True)
|
||||
contact_name: str
|
||||
contact_email: EmailStr = Field(strip_whitespace=True)
|
||||
|
||||
|
||||
class OrgResponse(BaseModel):
|
||||
"""Response model for organization."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
contact_name: str
|
||||
contact_email: str
|
||||
conversation_expiration: int | None = None
|
||||
agent: str | None = None
|
||||
default_max_iterations: int | None = None
|
||||
security_analyzer: str | None = None
|
||||
confirmation_mode: bool | None = None
|
||||
default_llm_model: str | None = None
|
||||
default_llm_api_key_for_byor: str | None = None
|
||||
default_llm_base_url: str | None = None
|
||||
remote_runtime_resource_factor: int | None = None
|
||||
enable_default_condenser: bool = True
|
||||
billing_margin: float | None = None
|
||||
enable_proactive_conversation_starters: bool = True
|
||||
sandbox_base_container_image: str | None = None
|
||||
sandbox_runtime_container_image: str | None = None
|
||||
org_version: int = 0
|
||||
mcp_config: dict | None = None
|
||||
search_api_key: str | None = None
|
||||
sandbox_api_key: str | None = None
|
||||
max_budget_per_task: float | None = None
|
||||
enable_solvability_analysis: bool | None = None
|
||||
v1_enabled: bool | None = None
|
||||
credits: float | None = None
|
||||
@@ -0,0 +1,117 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
LiteLLMIntegrationError,
|
||||
OrgCreate,
|
||||
OrgDatabaseError,
|
||||
OrgNameExistsError,
|
||||
OrgResponse,
|
||||
)
|
||||
from storage.org_service import OrgService
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
# Initialize API router
|
||||
org_router = APIRouter(prefix='/api/organizations')
|
||||
|
||||
|
||||
@org_router.post('', response_model=OrgResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_org(
|
||||
org_data: OrgCreate,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
) -> OrgResponse:
|
||||
"""Create a new organization.
|
||||
|
||||
This endpoint allows authenticated users with @openhands.dev email to create
|
||||
a new organization. The user who creates the organization automatically becomes
|
||||
its owner.
|
||||
|
||||
Args:
|
||||
org_data: Organization creation data
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The created organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if user email domain is not @openhands.dev
|
||||
HTTPException: 409 if organization name already exists
|
||||
HTTPException: 500 if creation fails
|
||||
"""
|
||||
logger.info(
|
||||
'Creating new organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_name': org_data.name,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to create organization
|
||||
org = await OrgService.create_org_with_owner(
|
||||
name=org_data.name,
|
||||
contact_name=org_data.contact_name,
|
||||
contact_email=org_data.contact_email,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Retrieve credits from LiteLLM
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse(
|
||||
id=str(org.id),
|
||||
name=org.name,
|
||||
contact_name=org.contact_name,
|
||||
contact_email=org.contact_email,
|
||||
conversation_expiration=org.conversation_expiration,
|
||||
agent=org.agent,
|
||||
default_max_iterations=org.default_max_iterations,
|
||||
security_analyzer=org.security_analyzer,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
default_llm_model=org.default_llm_model,
|
||||
default_llm_base_url=org.default_llm_base_url,
|
||||
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
|
||||
enable_default_condenser=org.enable_default_condenser,
|
||||
billing_margin=org.billing_margin,
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters,
|
||||
sandbox_base_container_image=org.sandbox_base_container_image,
|
||||
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
|
||||
org_version=org.org_version,
|
||||
mcp_config=org.mcp_config,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
credits=credits,
|
||||
)
|
||||
except OrgNameExistsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=str(e),
|
||||
)
|
||||
except LiteLLMIntegrationError as e:
|
||||
logger.error(
|
||||
'LiteLLM integration failed',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create LiteLLM integration',
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database operation failed',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create organization',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error creating organization',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
@@ -23,6 +23,7 @@ from sqlalchemy import orm
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import LLMConfig, OpenHandsConfig
|
||||
@@ -646,16 +647,18 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
"""
|
||||
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation_metadata:
|
||||
if not conversation_metadata_saas:
|
||||
raise ValueError(f'No conversation found {conversation_id}')
|
||||
|
||||
return conversation_metadata.user_id
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
|
||||
async def _get_runtime_status_from_nested_runtime(
|
||||
self, session_api_key: Any | None, nested_url: str, conversation_id: str
|
||||
@@ -994,9 +997,17 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
with session_maker() as session:
|
||||
# Only include conversations updated in the past week
|
||||
one_week_ago = datetime.now(UTC) - timedelta(days=7)
|
||||
query = session.query(StoredConversationMetadata.conversation_id).filter(
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
StoredConversationMetadata.last_updated_at >= one_week_ago,
|
||||
query = (
|
||||
session.query(StoredConversationMetadata.conversation_id)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
StoredConversationMetadata.last_updated_at >= one_week_ago,
|
||||
)
|
||||
)
|
||||
user_conversation_ids = set(query)
|
||||
return user_conversation_ids
|
||||
@@ -1070,11 +1081,16 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
if conversation_metadata is None:
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
if conversation_metadata is None or conversation_metadata_saas is None:
|
||||
# Conversation is running in different server
|
||||
return
|
||||
|
||||
user_id = conversation_metadata.user_id
|
||||
user_id = conversation_metadata_saas.user_id
|
||||
|
||||
# Get the id of the next event which is not present
|
||||
events_dir = get_conversation_events_dir(
|
||||
|
||||
@@ -241,7 +241,7 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
|
||||
return SharedConversation(
|
||||
id=UUID(stored.conversation_id),
|
||||
created_by_user_id=stored.user_id if stored.user_id else None,
|
||||
created_by_user_id=None, # user_id is no longer stored in conversation metadata
|
||||
sandbox_id=stored.sandbox_id,
|
||||
selected_repository=stored.selected_repository,
|
||||
selected_branch=stored.selected_branch,
|
||||
|
||||
@@ -0,0 +1,350 @@
|
||||
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy import func, select
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user import User
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_info_service import (
|
||||
AppConversationInfoService,
|
||||
AppConversationInfoServiceInjector,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
AppConversationInfoPage,
|
||||
AppConversationSortOrder,
|
||||
)
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
|
||||
|
||||
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
|
||||
|
||||
async def _secure_select(self):
|
||||
query = (
|
||||
select(StoredConversationMetadata)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
|
||||
async def _secure_select_with_saas_metadata(self):
|
||||
"""Select query that includes SAAS metadata for retrieving user_id."""
|
||||
query = (
|
||||
select(StoredConversationMetadata, StoredConversationMetadataSaas)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
|
||||
async def search_app_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
include_sub_conversations: bool = False,
|
||||
) -> AppConversationInfoPage:
|
||||
"""Search for conversations with user_id from SAAS metadata."""
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
|
||||
# Conditionally exclude sub-conversations based on the parameter
|
||||
if not include_sub_conversations:
|
||||
# Exclude sub-conversations (only include top-level conversations)
|
||||
query = query.where(
|
||||
StoredConversationMetadata.parent_conversation_id.is_(None)
|
||||
)
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
# Add sort order
|
||||
if sort_order == AppConversationSortOrder.CREATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.created_at)
|
||||
elif sort_order == AppConversationSortOrder.CREATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.created_at.desc())
|
||||
elif sort_order == AppConversationSortOrder.UPDATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at)
|
||||
elif sort_order == AppConversationSortOrder.UPDATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
|
||||
elif sort_order == AppConversationSortOrder.TITLE:
|
||||
query = query.order_by(StoredConversationMetadata.title)
|
||||
elif sort_order == AppConversationSortOrder.TITLE_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.title.desc())
|
||||
|
||||
# Apply pagination
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
query = query.offset(offset)
|
||||
except ValueError:
|
||||
# If page_id is not a valid integer, start from beginning
|
||||
offset = 0
|
||||
else:
|
||||
offset = 0
|
||||
|
||||
# Apply limit and get one extra to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(rows) > limit
|
||||
if has_more:
|
||||
rows = rows[:limit]
|
||||
|
||||
items = [
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
for stored_metadata, saas_metadata in rows
|
||||
]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
return AppConversationInfoPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
async def count_app_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count conversations matching the given filters with SAAS metadata."""
|
||||
query = (
|
||||
select(func.count(StoredConversationMetadata.conversation_id))
|
||||
.select_from(
|
||||
StoredConversationMetadata.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
# Apply user filtering
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
count = result.scalar()
|
||||
return count or 0
|
||||
|
||||
def _apply_filters_with_saas_metadata(
|
||||
self,
|
||||
query,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
):
|
||||
"""Apply filters to query that includes SAAS metadata."""
|
||||
# Apply the same filters as the base class
|
||||
conditions = []
|
||||
if title__contains is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.title.like(f'%{title__contains}%')
|
||||
)
|
||||
|
||||
if created_at__gte is not None:
|
||||
conditions.append(StoredConversationMetadata.created_at >= created_at__gte)
|
||||
|
||||
if created_at__lt is not None:
|
||||
conditions.append(StoredConversationMetadata.created_at < created_at__lt)
|
||||
|
||||
if updated_at__gte is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.last_updated_at >= updated_at__gte
|
||||
)
|
||||
|
||||
if updated_at__lt is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.last_updated_at < updated_at__lt
|
||||
)
|
||||
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
return query
|
||||
|
||||
async def get_app_conversation_info(
|
||||
self, conversation_id: UUID
|
||||
) -> AppConversationInfo | None:
|
||||
"""Get conversation info with user_id from SAAS metadata."""
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
query = query.where(
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
result_set = await self.db_session.execute(query)
|
||||
result = result_set.first()
|
||||
if result:
|
||||
stored_metadata, saas_metadata = result
|
||||
return self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
return None
|
||||
|
||||
async def batch_get_app_conversation_info(
|
||||
self, conversation_ids: list[UUID]
|
||||
) -> list[AppConversationInfo | None]:
|
||||
"""Batch get conversation info with user_id from SAAS metadata."""
|
||||
conversation_id_strs = [
|
||||
str(conversation_id) for conversation_id in conversation_ids
|
||||
]
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
query = query.where(
|
||||
StoredConversationMetadata.conversation_id.in_(conversation_id_strs)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Create a mapping of conversation_id to (metadata, saas_metadata)
|
||||
info_by_id = {}
|
||||
for stored_metadata, saas_metadata in rows:
|
||||
info_by_id[stored_metadata.conversation_id] = (
|
||||
stored_metadata,
|
||||
saas_metadata,
|
||||
)
|
||||
|
||||
results: list[AppConversationInfo | None] = []
|
||||
for conversation_id in conversation_id_strs:
|
||||
if conversation_id in info_by_id:
|
||||
stored_metadata, saas_metadata = info_by_id[conversation_id]
|
||||
results.append(
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
)
|
||||
else:
|
||||
results.append(None)
|
||||
|
||||
return results
|
||||
|
||||
async def save_app_conversation_info(
|
||||
self, info: AppConversationInfo
|
||||
) -> AppConversationInfo:
|
||||
"""Save conversation info and create/update SAAS metadata with user_id and org_id."""
|
||||
# Save the base conversation metadata
|
||||
await super().save_app_conversation_info(info)
|
||||
|
||||
# Get current user_id for SAAS metadata
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
# Convert string user_id to UUID
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
user_query = select(User).where(User.id == user_id_uuid)
|
||||
result = await self.db_session.execute(user_query)
|
||||
user = result.scalar_one_or_none()
|
||||
assert user
|
||||
|
||||
# Check if SAAS metadata already exists
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(info.id)
|
||||
)
|
||||
result = await self.db_session.execute(saas_query)
|
||||
existing_saas_metadata = result.scalar_one_or_none()
|
||||
assert existing_saas_metadata is None or (
|
||||
existing_saas_metadata.user_id == user_id_uuid
|
||||
and existing_saas_metadata.org_id == user.current_org_id
|
||||
)
|
||||
|
||||
if not existing_saas_metadata:
|
||||
# Create new SAAS metadata
|
||||
# Set org_id to user_id as specified in requirements
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=str(info.id),
|
||||
user_id=user_id_uuid,
|
||||
org_id=user.current_org_id,
|
||||
)
|
||||
self.db_session.add(saas_metadata)
|
||||
|
||||
await self.db_session.commit()
|
||||
|
||||
return info
|
||||
|
||||
def _to_info_with_user_id(
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
saas_metadata: StoredConversationMetadataSaas,
|
||||
) -> AppConversationInfo:
|
||||
"""Convert stored metadata to AppConversationInfo with user_id from SAAS metadata."""
|
||||
# Use the base _to_info method to get the basic info
|
||||
info = self._to_info(stored)
|
||||
|
||||
# Override the created_by_user_id with the user_id from SAAS metadata
|
||||
info.created_by_user_id = (
|
||||
str(saas_metadata.user_id) if saas_metadata.user_id else None
|
||||
)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
class SaasAppConversationInfoServiceInjector(AppConversationInfoServiceInjector):
|
||||
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[AppConversationInfoService, None]:
|
||||
from openhands.app_server.config import (
|
||||
get_db_session,
|
||||
get_user_context,
|
||||
)
|
||||
|
||||
async with (
|
||||
get_user_context(state, request) as user_context,
|
||||
get_db_session(state, request) as db_session,
|
||||
):
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=db_session, user_context=user_context
|
||||
)
|
||||
yield service
|
||||
@@ -0,0 +1,85 @@
|
||||
from storage.api_key import ApiKey
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.billing_session_type import BillingSessionType
|
||||
from storage.conversation_callback import CallbackStatus, ConversationCallback
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.experiment_assignment import ExperimentAssignment
|
||||
from storage.feedback import ConversationFeedback, Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_dc_conversation import JiraDcConversation
|
||||
from storage.jira_dc_user import JiraDcUser
|
||||
from storage.jira_dc_workspace import JiraDcWorkspace
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
from storage.linear_conversation import LinearConversation
|
||||
from storage.linear_user import LinearUser
|
||||
from storage.linear_workspace import LinearWorkspace
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.openhands_pr import OpenhandsPR
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.proactive_convos import ProactiveConversation
|
||||
from storage.role import Role
|
||||
from storage.slack_conversation import SlackConversation
|
||||
from storage.slack_team import SlackTeam
|
||||
from storage.slack_user import SlackUser
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.subscription_access_status import SubscriptionAccessStatus
|
||||
from storage.user import User
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
__all__ = [
|
||||
'ApiKey',
|
||||
'AuthTokens',
|
||||
'BillingSession',
|
||||
'BillingSessionType',
|
||||
'CallbackStatus',
|
||||
'ConversationCallback',
|
||||
'ConversationFeedback',
|
||||
'StoredConversationMetadataSaas',
|
||||
'ConversationWork',
|
||||
'ExperimentAssignment',
|
||||
'Feedback',
|
||||
'GithubAppInstallation',
|
||||
'GitlabWebhook',
|
||||
'JiraConversation',
|
||||
'JiraDcConversation',
|
||||
'JiraDcUser',
|
||||
'JiraDcWorkspace',
|
||||
'JiraUser',
|
||||
'JiraWorkspace',
|
||||
'LinearConversation',
|
||||
'LinearUser',
|
||||
'LinearWorkspace',
|
||||
'MaintenanceTask',
|
||||
'MaintenanceTaskStatus',
|
||||
'OpenhandsPR',
|
||||
'Org',
|
||||
'OrgMember',
|
||||
'ProactiveConversation',
|
||||
'Role',
|
||||
'SlackConversation',
|
||||
'SlackTeam',
|
||||
'SlackUser',
|
||||
'StoredConversationMetadata',
|
||||
'StoredOfflineToken',
|
||||
'StoredRepository',
|
||||
'StoredCustomSecrets',
|
||||
'StripeCustomer',
|
||||
'SubscriptionAccess',
|
||||
'SubscriptionAccessStatus',
|
||||
'User',
|
||||
'UserRepositoryMap',
|
||||
'UserSettings',
|
||||
'WebhookStatus',
|
||||
]
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, DateTime, Integer, String, text
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -11,9 +13,13 @@ class ApiKey(Base):
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
key = Column(String(255), nullable=False, unique=True, index=True)
|
||||
user_id = Column(String(255), nullable=False, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
name = Column(String(255), nullable=True)
|
||||
created_at = Column(
|
||||
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
|
||||
)
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
expires_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='api_keys')
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.api_key import ApiKey
|
||||
from storage.database import session_maker
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -39,10 +40,15 @@ class ApiKeyStore:
|
||||
The generated API key
|
||||
"""
|
||||
api_key = self.generate_api_key()
|
||||
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key, user_id=user_id, name=name, expires_at=expires_at
|
||||
key=api_key,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=name,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
session.add(key_record)
|
||||
session.commit()
|
||||
@@ -108,8 +114,15 @@ class ApiKeyStore:
|
||||
|
||||
def list_api_keys(self, user_id: str) -> list[dict]:
|
||||
"""List all API keys for a user."""
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
keys = session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
||||
keys = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
@@ -124,9 +137,14 @@ class ApiKeyStore:
|
||||
]
|
||||
|
||||
def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
keys: list[ApiKey] = (
|
||||
session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
for key in keys:
|
||||
if key.name == 'MCP_API_KEY':
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import DECIMAL, Column, DateTime, Enum, String
|
||||
from sqlalchemy import DECIMAL, Column, DateTime, Enum, ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -11,9 +13,9 @@ class BillingSession(Base): # type: ignore
|
||||
"""
|
||||
|
||||
__tablename__ = 'billing_sessions'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
status = Column(
|
||||
Enum(
|
||||
'in_progress',
|
||||
@@ -24,15 +26,6 @@ class BillingSession(Base): # type: ignore
|
||||
),
|
||||
default='in_progress',
|
||||
)
|
||||
billing_session_type = Column(
|
||||
Enum(
|
||||
'DIRECT_PAYMENT',
|
||||
'MONTHLY_SUBSCRIPTION',
|
||||
name='billing_session_type_enum',
|
||||
),
|
||||
nullable=False,
|
||||
default='DIRECT_PAYMENT',
|
||||
)
|
||||
price = Column(DECIMAL(19, 4), nullable=False)
|
||||
price_code = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
@@ -43,3 +36,6 @@ class BillingSession(Base): # type: ignore
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC), # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='billing_sessions')
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -10,7 +15,6 @@ from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text, text
|
||||
from sqlalchemy import Enum as SQLEnum
|
||||
from storage.base import Base
|
||||
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
@@ -33,7 +37,7 @@ class ConversationCallbackProcessor(BaseModel, ABC):
|
||||
async def __call__(
|
||||
self,
|
||||
callback: ConversationCallback,
|
||||
observation: AgentStateChangedObservation,
|
||||
observation: 'AgentStateChangedObservation',
|
||||
) -> None:
|
||||
"""
|
||||
Process a conversation event.
|
||||
|
||||
+26
-110
@@ -1,122 +1,38 @@
|
||||
import asyncio
|
||||
import os
|
||||
"""
|
||||
Database connection module for enterprise storage.
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.util import await_only
|
||||
This is for backwards compatibility with V0.
|
||||
|
||||
DB_HOST = os.environ.get('DB_HOST', 'localhost') # for non-GCP environments
|
||||
DB_PORT = os.environ.get('DB_PORT', '5432') # for non-GCP environments
|
||||
DB_USER = os.environ.get('DB_USER', 'postgres')
|
||||
DB_PASS = os.environ.get('DB_PASS', 'postgres').strip()
|
||||
DB_NAME = os.environ.get('DB_NAME', 'openhands')
|
||||
This module provides database engines and session makers by delegating to the
|
||||
centralized DbSessionInjector from app_server/config.py. This ensures a single
|
||||
source of truth for database connection configuration.
|
||||
"""
|
||||
|
||||
GCP_DB_INSTANCE = os.environ.get('GCP_DB_INSTANCE') # for GCP environments
|
||||
GCP_PROJECT = os.environ.get('GCP_PROJECT')
|
||||
GCP_REGION = os.environ.get('GCP_REGION')
|
||||
|
||||
POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '25'))
|
||||
MAX_OVERFLOW = int(os.environ.get('DB_MAX_OVERFLOW', '10'))
|
||||
POOL_RECYCLE = int(os.environ.get('DB_POOL_RECYCLE', '1800'))
|
||||
|
||||
# Initialize Cloud SQL Connector once at module level for GCP environments.
|
||||
_connector = None
|
||||
import contextlib
|
||||
|
||||
|
||||
def _get_db_engine():
|
||||
if GCP_DB_INSTANCE: # GCP environments
|
||||
def _get_db_session_injector():
|
||||
from openhands.app_server.config import get_global_config
|
||||
|
||||
def get_db_connection():
|
||||
global _connector
|
||||
from google.cloud.sql.connector import Connector
|
||||
|
||||
if not _connector:
|
||||
_connector = Connector()
|
||||
instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
|
||||
return _connector.connect(
|
||||
instance_string, 'pg8000', user=DB_USER, password=DB_PASS, db=DB_NAME
|
||||
)
|
||||
|
||||
return create_engine(
|
||||
'postgresql+pg8000://',
|
||||
creator=get_db_connection,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
else:
|
||||
host_string = (
|
||||
f'postgresql+pg8000://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
|
||||
)
|
||||
return create_engine(
|
||||
host_string,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
_config = get_global_config()
|
||||
return _config.db_session
|
||||
|
||||
|
||||
async def async_creator():
|
||||
from google.cloud.sql.connector import Connector
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
async with Connector(loop=loop) as connector:
|
||||
conn = await connector.connect_async(
|
||||
f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}', # Cloud SQL instance connection name"
|
||||
'asyncpg',
|
||||
user=DB_USER,
|
||||
password=DB_PASS,
|
||||
db=DB_NAME,
|
||||
)
|
||||
return conn
|
||||
def session_maker():
|
||||
db_session_injector = _get_db_session_injector()
|
||||
session_maker = db_session_injector.get_session_maker()
|
||||
return session_maker()
|
||||
|
||||
|
||||
def _get_async_db_engine():
|
||||
if GCP_DB_INSTANCE: # GCP environments
|
||||
|
||||
def adapted_creator():
|
||||
dbapi = engine.dialect.dbapi
|
||||
from sqlalchemy.dialects.postgresql.asyncpg import (
|
||||
AsyncAdapt_asyncpg_connection,
|
||||
)
|
||||
|
||||
return AsyncAdapt_asyncpg_connection(
|
||||
dbapi,
|
||||
await_only(async_creator()),
|
||||
prepared_statement_cache_size=100,
|
||||
)
|
||||
|
||||
# create async connection pool with wrapped creator
|
||||
return create_async_engine(
|
||||
'postgresql+asyncpg://',
|
||||
creator=adapted_creator,
|
||||
# Use NullPool to disable connection pooling and avoid event loop issues
|
||||
poolclass=NullPool,
|
||||
)
|
||||
else:
|
||||
host_string = (
|
||||
f'postgresql+asyncpg://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
|
||||
)
|
||||
return create_async_engine(
|
||||
host_string,
|
||||
# Use NullPool to disable connection pooling and avoid event loop issues
|
||||
poolclass=NullPool,
|
||||
)
|
||||
@contextlib.asynccontextmanager
|
||||
async def a_session_maker():
|
||||
db_session_injector = _get_db_session_injector()
|
||||
a_session_maker = await db_session_injector.get_async_session_maker()
|
||||
async with a_session_maker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
engine = _get_db_engine()
|
||||
session_maker = sessionmaker(bind=engine)
|
||||
|
||||
a_engine = _get_async_db_engine()
|
||||
a_session_maker = sessionmaker(
|
||||
bind=a_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
# Configure the session to use the same connection for all operations in a transaction
|
||||
# This helps prevent the "Task got Future attached to a different loop" error
|
||||
future=True,
|
||||
)
|
||||
def get_engine():
|
||||
db_session_injector = _get_db_session_injector()
|
||||
engine = db_session_injector.get_db_engine()
|
||||
return engine
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
import binascii
|
||||
import hashlib
|
||||
from base64 import b64decode, b64encode
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from pydantic import SecretStr
|
||||
from server.config import get_config
|
||||
|
||||
_jwt_service = None
|
||||
_fernet = None
|
||||
|
||||
|
||||
def encrypt_model(encrypt_keys: list, model_instance) -> dict:
|
||||
return encrypt_kwargs(encrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def decrypt_model(decrypt_keys: list, model_instance) -> dict:
|
||||
return decrypt_kwargs(decrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def encrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if isinstance(value, dict):
|
||||
encrypt_kwargs(encrypt_keys, value)
|
||||
continue
|
||||
|
||||
if key in encrypt_keys:
|
||||
value = encrypt_value(value)
|
||||
kwargs[key] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def decrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
value = decrypt_value(value)
|
||||
kwargs[key] = value
|
||||
except binascii.Error:
|
||||
pass # Key is in legacy format...
|
||||
return kwargs
|
||||
|
||||
|
||||
def encrypt_value(value: str | SecretStr) -> str:
|
||||
return get_jwt_service().create_jwe_token(
|
||||
{'v': value.get_secret_value() if isinstance(value, SecretStr) else value}
|
||||
)
|
||||
|
||||
|
||||
def decrypt_value(value: str | SecretStr) -> str:
|
||||
token = get_jwt_service().decrypt_jwe_token(
|
||||
value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
)
|
||||
return token['v']
|
||||
|
||||
|
||||
def get_jwt_service():
|
||||
from openhands.app_server.config import get_global_config
|
||||
|
||||
global _jwt_service
|
||||
if _jwt_service is None:
|
||||
jwt_service_injector = get_global_config().jwt
|
||||
assert jwt_service_injector is not None
|
||||
_jwt_service = jwt_service_injector.get_jwt_service()
|
||||
return _jwt_service
|
||||
|
||||
|
||||
def decrypt_legacy_model(decrypt_keys: list, model_instance) -> dict:
|
||||
return decrypt_legacy_kwargs(decrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def decrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
value = decrypt_legacy_value(value)
|
||||
kwargs[key] = value
|
||||
except binascii.Error:
|
||||
pass # Key is in legacy format...
|
||||
except InvalidToken:
|
||||
pass # Key not encrypted...
|
||||
return kwargs
|
||||
|
||||
|
||||
def decrypt_legacy_value(value: str | SecretStr) -> str:
|
||||
if isinstance(value, SecretStr):
|
||||
return (
|
||||
get_fernet().decrypt(b64decode(value.get_secret_value().encode())).decode()
|
||||
)
|
||||
else:
|
||||
return get_fernet().decrypt(b64decode(value.encode())).decode()
|
||||
|
||||
|
||||
def get_fernet():
|
||||
global _fernet
|
||||
if _fernet is None:
|
||||
jwt_secret = get_config().jwt_secret.get_secret_value()
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
_fernet = Fernet(fernet_key)
|
||||
return _fernet
|
||||
|
||||
|
||||
def model_to_kwargs(model_instance):
|
||||
return {
|
||||
column.name: getattr(model_instance, column.name)
|
||||
for column in model_instance.__table__.columns
|
||||
}
|
||||
@@ -1,7 +1,16 @@
|
||||
import sys
|
||||
from enum import IntEnum
|
||||
|
||||
from sqlalchemy import ARRAY, Boolean, Column, DateTime, Integer, String, Text, text
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
text,
|
||||
)
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
|
||||
@@ -118,9 +118,7 @@ class JiraIntegrationStore:
|
||||
.first()
|
||||
)
|
||||
|
||||
async def get_workspace_by_name(
|
||||
self, workspace_name: str
|
||||
) -> Optional[JiraWorkspace]:
|
||||
async def get_workspace_by_name(self, workspace_name: str) -> JiraWorkspace | None:
|
||||
"""Retrieve workspace by name."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
|
||||
@@ -0,0 +1,863 @@
|
||||
"""
|
||||
Store class for managing organizational settings.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
# Timeout in seconds for BYOR key verification requests to LiteLLM
|
||||
BYOR_KEY_VERIFICATION_TIMEOUT = 5.0
|
||||
|
||||
# A very large number to represent "unlimited" until LiteLLM fixes their unlimited update bug.
|
||||
UNLIMITED_BUDGET_SETTING = 1000000000.0
|
||||
|
||||
|
||||
class LiteLlmManager:
|
||||
"""Manage LiteLLM interactions."""
|
||||
|
||||
@staticmethod
|
||||
async def create_entries(
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
oss_settings: Settings,
|
||||
create_user: bool,
|
||||
) -> Settings | None:
|
||||
logger.info(
|
||||
'SettingsStore:update_settings_with_litellm_default:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
key = LITE_LLM_API_KEY
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
token_manager = TokenManager()
|
||||
keycloak_user_info = (
|
||||
await token_manager.get_user_info_from_user_id(keycloak_user_id) or {}
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
)
|
||||
|
||||
if create_user:
|
||||
await LiteLlmManager._create_user(
|
||||
client, keycloak_user_info.get('email'), keycloak_user_id
|
||||
)
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
)
|
||||
|
||||
key = await LiteLlmManager._generate_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
org_id,
|
||||
f'OpenHands Cloud - user {keycloak_user_id} - org {org_id}',
|
||||
None,
|
||||
)
|
||||
|
||||
oss_settings.agent = 'CodeActAgent'
|
||||
# Use the model corresponding to the current user settings version
|
||||
oss_settings.llm_model = get_default_litellm_model()
|
||||
oss_settings.llm_api_key = SecretStr(key)
|
||||
oss_settings.llm_base_url = LITE_LLM_API_URL
|
||||
return oss_settings
|
||||
|
||||
@staticmethod
|
||||
async def migrate_entries(
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
user_settings: UserSettings,
|
||||
) -> UserSettings | None:
|
||||
logger.info(
|
||||
'SettingsStore:umigrate_lite_llm_entries:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
user_json = await LiteLlmManager._get_user(client, keycloak_user_id)
|
||||
if not user_json:
|
||||
return None
|
||||
user_info = user_json['user_info']
|
||||
max_budget = user_info.get('max_budget', 0.0)
|
||||
spend = user_info.get('spend', 0.0)
|
||||
# In upgrade to V4, we no longer use billing margin, but instead apply this directly
|
||||
# in litellm. The default billing marign was 2 before this (hence the magic numbers below)
|
||||
if (
|
||||
user_settings
|
||||
and user_settings.user_version < 4
|
||||
and user_settings.billing_margin
|
||||
and user_settings.billing_margin != 1.0
|
||||
):
|
||||
billing_margin = user_settings.billing_margin
|
||||
logger.info(
|
||||
'user_settings_v4_budget_upgrade',
|
||||
extra={
|
||||
'max_budget': max_budget,
|
||||
'billing_margin': billing_margin,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
max_budget *= billing_margin
|
||||
spend *= billing_margin
|
||||
|
||||
if not max_budget:
|
||||
# if max_budget is None, then we've already migrated the User
|
||||
return None
|
||||
credits = max(max_budget - spend, 0.0)
|
||||
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
await LiteLlmManager._update_user(
|
||||
client, keycloak_user_id, max_budget=UNLIMITED_BUDGET_SETTING
|
||||
)
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key:
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key,
|
||||
team_id=org_id,
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key_for_byor:
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key_for_byor,
|
||||
team_id=org_id,
|
||||
)
|
||||
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
async def update_team_and_users_budget(
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
await LiteLlmManager._update_team(client, team_id, None, max_budget)
|
||||
team_info = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_info:
|
||||
return None
|
||||
# TODO: change to use bulk update endpoint
|
||||
for membership in team_info.get('team_memberships', []):
|
||||
user_id = membership.get('user_id')
|
||||
if not user_id:
|
||||
continue
|
||||
await LiteLlmManager._update_user_in_team(
|
||||
client, user_id, team_id, max_budget
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _create_team(
|
||||
client: httpx.AsyncClient,
|
||||
team_alias: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/new',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'team_alias': team_alias,
|
||||
'models': [],
|
||||
'max_budget': max_budget,
|
||||
'spend': 0,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
# Team failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if (
|
||||
response.status_code == 400
|
||||
and 'already exists. Please use a different team id' in response.text
|
||||
):
|
||||
# team already exists, so update, then return
|
||||
await LiteLlmManager._update_team(
|
||||
client, team_id, team_alias, max_budget
|
||||
)
|
||||
return
|
||||
logger.error(
|
||||
'error_creating_litellm_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'team_id': team_id,
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_team(client: httpx.AsyncClient, team_id: str) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
"""Get a team from litellm with the id matching that given."""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/team/info?team_id={team_id}',
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
async def _update_team(
|
||||
client: httpx.AsyncClient,
|
||||
team_id: str,
|
||||
team_alias: str | None,
|
||||
max_budget: float | None,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
json_data: dict[str, Any] = {
|
||||
'team_id': team_id,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
}
|
||||
|
||||
if max_budget is not None:
|
||||
json_data['max_budget'] = max_budget
|
||||
|
||||
if team_alias is not None:
|
||||
json_data['team_alias'] = team_alias
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/update',
|
||||
json=json_data,
|
||||
)
|
||||
|
||||
# Team failed to update in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _create_user(
|
||||
client: httpx.AsyncClient,
|
||||
email: str | None,
|
||||
keycloak_user_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': email,
|
||||
'models': [],
|
||||
'user_id': keycloak_user_id,
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': False,
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
if not response.is_success:
|
||||
logger.warning(
|
||||
'duplicate_user_email',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'email': email,
|
||||
},
|
||||
)
|
||||
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': None,
|
||||
'models': [],
|
||||
'user_id': keycloak_user_id,
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': False,
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# User failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if (
|
||||
response.status_code in (400, 409)
|
||||
and 'already exists' in response.text
|
||||
):
|
||||
logger.warning(
|
||||
'litellm_user_already_exists',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
return
|
||||
logger.error(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'email': None,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_user(client: httpx.AsyncClient, user_id: str) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
"""Get a user from litellm with the id matching that given."""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
async def _update_user(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
**kwargs,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
payload = {
|
||||
'user_id': keycloak_user_id,
|
||||
}
|
||||
payload.update(kwargs)
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/update',
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _update_key(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
key: str,
|
||||
**kwargs,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
payload = {
|
||||
'key': key,
|
||||
}
|
||||
payload.update(kwargs)
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/update',
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
if response.status_code == 401:
|
||||
logger.warning(
|
||||
'invalid_litellm_key_during_update',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
return
|
||||
logger.error(
|
||||
'error_updating_litellm_key',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _delete_user(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [keycloak_user_id]}
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_deleting_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _delete_team(
|
||||
client: httpx.AsyncClient,
|
||||
team_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/delete',
|
||||
json={'team_ids': [team_id]},
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
if response.status_code == 404:
|
||||
# Team doesn't exist, that's fine
|
||||
logger.info(
|
||||
'Team already deleted or does not exist',
|
||||
extra={'team_id': team_id},
|
||||
)
|
||||
return
|
||||
logger.error(
|
||||
'error_deleting_litellm_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'team_id': team_id,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.info(
|
||||
'LiteLlmManager:_delete_team:team_deleted',
|
||||
extra={'team_id': team_id},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _add_user_to_team(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_add',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'member': {'user_id': keycloak_user_id, 'role': 'user'},
|
||||
'max_budget_in_team': max_budget,
|
||||
},
|
||||
)
|
||||
# Failed to add user to team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if (
|
||||
response.status_code == 400
|
||||
and 'already in team' in response.text.lower()
|
||||
):
|
||||
logger.warning(
|
||||
'user_already_in_team',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'team_id': team_id,
|
||||
},
|
||||
)
|
||||
return
|
||||
logger.error(
|
||||
'error_adding_litellm_user_to_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_user_team_info(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
team_info = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_info:
|
||||
return None
|
||||
|
||||
# Filter team_memberships based on team_id and keycloak_user_id
|
||||
user_membership = next(
|
||||
(
|
||||
membership
|
||||
for membership in team_info.get('team_memberships', [])
|
||||
if membership.get('user_id') == keycloak_user_id
|
||||
and membership.get('team_id') == team_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return user_membership
|
||||
|
||||
@staticmethod
|
||||
async def _update_user_in_team(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_update',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'max_budget_in_team': max_budget,
|
||||
},
|
||||
)
|
||||
# Failed to update user in team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_user_in_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _generate_key(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str | None,
|
||||
key_alias: str | None,
|
||||
metadata: dict | None,
|
||||
) -> str | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
json_data: dict[str, Any] = {
|
||||
'user_id': keycloak_user_id,
|
||||
'models': [],
|
||||
}
|
||||
|
||||
if team_id is not None:
|
||||
json_data['team_id'] = team_id
|
||||
|
||||
if key_alias is not None:
|
||||
json_data['key_alias'] = key_alias
|
||||
|
||||
if metadata is not None:
|
||||
json_data['metadata'] = metadata
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json=json_data,
|
||||
)
|
||||
# Failed to generate user key for team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_generate_user_team_key',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
'team_id': team_id,
|
||||
'key_alias': key_alias,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json['key']
|
||||
logger.info(
|
||||
'LiteLlmManager:_lite_llm_generate_user_team_key:key_created',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'team_id': team_id,
|
||||
'key_alias': key_alias,
|
||||
},
|
||||
)
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
async def verify_key(key: str, user_id: str) -> bool:
|
||||
"""Verify that a key is valid in LiteLLM by making a lightweight API call.
|
||||
|
||||
Args:
|
||||
key: The key to verify
|
||||
user_id: The user ID for logging purposes
|
||||
|
||||
Returns:
|
||||
True if the key is verified as valid, False if verification fails or key is invalid.
|
||||
Returns False on network errors/timeouts to ensure we don't return potentially invalid keys.
|
||||
"""
|
||||
if not (LITE_LLM_API_URL and key):
|
||||
return False
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
timeout=BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||
) as client:
|
||||
# Make a lightweight request to verify the key
|
||||
# Using /v1/models endpoint as it's lightweight and requires authentication
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/v1/models',
|
||||
headers={
|
||||
'Authorization': f'Bearer {key}',
|
||||
},
|
||||
)
|
||||
|
||||
# Only 200 status code indicates valid key
|
||||
if response.status_code == 200:
|
||||
logger.debug(
|
||||
'BYOR key verification successful',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
|
||||
# All other status codes (401, 403, 500, etc.) are treated as invalid
|
||||
# This includes authentication errors and server errors
|
||||
logger.warning(
|
||||
'BYOR key verification failed - treating as invalid',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': response.status_code,
|
||||
'key_prefix': key[:10] + '...' if len(key) > 10 else key,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
except (httpx.TimeoutException, Exception) as e:
|
||||
# Any exception (timeout, network error, etc.) means we can't verify
|
||||
# Return False to trigger regeneration rather than returning potentially invalid key
|
||||
logger.warning(
|
||||
'BYOR key verification error - treating as invalid to ensure key validity',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
'error_type': type(e).__name__,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _get_key_info(
|
||||
client: httpx.AsyncClient,
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
) -> dict | None:
|
||||
from storage.user_store import UserStore
|
||||
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
user = await UserStore.get_user_by_id_async(keycloak_user_id)
|
||||
if not user:
|
||||
return {}
|
||||
|
||||
org_member = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return {}
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/key/info?key={org_member.llm_api_key}'
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key_info = response_json.get('info')
|
||||
if not key_info:
|
||||
return {}
|
||||
return {
|
||||
'key_max_budget': key_info.get('max_budget'),
|
||||
'key_spend': key_info.get('spend'),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _delete_key_by_alias(
|
||||
client: httpx.AsyncClient,
|
||||
key_alias: str,
|
||||
):
|
||||
"""Delete a key from LiteLLM by its alias.
|
||||
|
||||
This is a best-effort operation that logs but does not raise on failure.
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/delete',
|
||||
json={
|
||||
'key_aliases': [key_alias],
|
||||
},
|
||||
)
|
||||
if response.is_success:
|
||||
logger.info(
|
||||
'LiteLlmManager:_delete_key_by_alias:key_deleted',
|
||||
extra={'key_alias': key_alias},
|
||||
)
|
||||
elif response.status_code != 404:
|
||||
# Log non-404 errors but don't fail
|
||||
logger.warning(
|
||||
'error_deleting_key_by_alias',
|
||||
extra={
|
||||
'key_alias': key_alias,
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _delete_key(
|
||||
client: httpx.AsyncClient,
|
||||
key_id: str,
|
||||
key_alias: str | None = None,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/delete',
|
||||
json={
|
||||
'keys': [key_id],
|
||||
},
|
||||
)
|
||||
# Failed to delete key...
|
||||
if not response.is_success:
|
||||
if response.status_code == 404:
|
||||
# Key doesn't exist by key_id. If we have a key_alias,
|
||||
# try deleting by alias to clean up any orphaned alias.
|
||||
if key_alias:
|
||||
await LiteLlmManager._delete_key_by_alias(client, key_alias)
|
||||
return
|
||||
logger.error(
|
||||
'error_deleting_key',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.info(
|
||||
'LiteLlmManager:_delete_key:key_deleted',
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def with_http_client(
|
||||
internal_fn: Callable[..., Awaitable[Any]],
|
||||
) -> Callable[..., Awaitable[Any]]:
|
||||
@functools.wraps(internal_fn)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async with httpx.AsyncClient(
|
||||
headers={'x-goog-api-key': LITE_LLM_API_KEY}
|
||||
) as client:
|
||||
return await internal_fn(client, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
# Public methods with injected client
|
||||
create_team = staticmethod(with_http_client(_create_team))
|
||||
get_team = staticmethod(with_http_client(_get_team))
|
||||
update_team = staticmethod(with_http_client(_update_team))
|
||||
create_user = staticmethod(with_http_client(_create_user))
|
||||
get_user = staticmethod(with_http_client(_get_user))
|
||||
update_user = staticmethod(with_http_client(_update_user))
|
||||
delete_user = staticmethod(with_http_client(_delete_user))
|
||||
delete_team = staticmethod(with_http_client(_delete_team))
|
||||
add_user_to_team = staticmethod(with_http_client(_add_user_to_team))
|
||||
get_user_team_info = staticmethod(with_http_client(_get_user_team_info))
|
||||
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
|
||||
generate_key = staticmethod(with_http_client(_generate_key))
|
||||
get_key_info = staticmethod(with_http_client(_get_key_info))
|
||||
delete_key = staticmethod(with_http_client(_delete_key))
|
||||
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import SecretStr
|
||||
from server.constants import DEFAULT_BILLING_MARGIN
|
||||
from sqlalchemy import JSON, UUID, Boolean, Column, Float, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import decrypt_value, encrypt_value
|
||||
|
||||
|
||||
class Org(Base): # type: ignore
|
||||
"""Organization model."""
|
||||
|
||||
__tablename__ = 'org'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
contact_name = Column(String, nullable=True)
|
||||
contact_email = Column(String, nullable=True)
|
||||
agent = Column(String, nullable=True)
|
||||
default_max_iterations = Column(Integer, nullable=True)
|
||||
security_analyzer = Column(String, nullable=True)
|
||||
confirmation_mode = Column(Boolean, nullable=True, default=False)
|
||||
default_llm_model = Column(String, nullable=True)
|
||||
default_llm_base_url = Column(String, nullable=True)
|
||||
remote_runtime_resource_factor = Column(Integer, nullable=True)
|
||||
enable_default_condenser = Column(Boolean, nullable=False, default=True)
|
||||
billing_margin = Column(Float, nullable=True, default=DEFAULT_BILLING_MARGIN)
|
||||
enable_proactive_conversation_starters = Column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
sandbox_base_container_image = Column(String, nullable=True)
|
||||
sandbox_runtime_container_image = Column(String, nullable=True)
|
||||
org_version = Column(Integer, nullable=False, default=0)
|
||||
mcp_config = Column(JSON, nullable=True)
|
||||
# encrypted column, don't set directly, set without the underscore
|
||||
_search_api_key = Column(String, nullable=True)
|
||||
# encrypted column, don't set directly, set without the underscore
|
||||
_sandbox_api_key = Column(String, nullable=True)
|
||||
max_budget_per_task = Column(Float, nullable=True)
|
||||
enable_solvability_analysis = Column(Boolean, nullable=True, default=False)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
conversation_expiration = Column(Integer, nullable=True)
|
||||
condenser_max_size = Column(Integer, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org_members = relationship('OrgMember', back_populates='org')
|
||||
current_users = relationship('User', back_populates='current_org')
|
||||
billing_sessions = relationship('BillingSession', back_populates='org')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='org'
|
||||
)
|
||||
user_secrets = relationship('StoredCustomSecrets', back_populates='org')
|
||||
api_keys = relationship('ApiKey', back_populates='org')
|
||||
slack_conversations = relationship('SlackConversation', back_populates='org')
|
||||
slack_users = relationship('SlackUser', back_populates='org')
|
||||
stripe_customers = relationship('StripeCustomer', back_populates='org')
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Handle known SQLAlchemy columns directly
|
||||
for key in list(kwargs):
|
||||
if hasattr(self.__class__, key):
|
||||
setattr(self, key, kwargs.pop(key))
|
||||
|
||||
# Handle custom property-style fields
|
||||
if 'search_api_key' in kwargs:
|
||||
self.search_api_key = kwargs.pop('search_api_key')
|
||||
if 'sandbox_api_key' in kwargs:
|
||||
self.sandbox_api_key = kwargs.pop('sandbox_api_key')
|
||||
|
||||
if kwargs:
|
||||
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
|
||||
|
||||
@property
|
||||
def search_api_key(self) -> SecretStr | None:
|
||||
if self._search_api_key:
|
||||
decrypted = decrypt_value(self._search_api_key)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@search_api_key.setter
|
||||
def search_api_key(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._search_api_key = encrypt_value(raw) if raw else None
|
||||
|
||||
@property
|
||||
def sandbox_api_key(self) -> SecretStr | None:
|
||||
if self._sandbox_api_key:
|
||||
decrypted = decrypt_value(self._sandbox_api_key)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@sandbox_api_key.setter
|
||||
def sandbox_api_key(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._sandbox_api_key = encrypt_value(raw) if raw else None
|
||||
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization-Member relationship.
|
||||
"""
|
||||
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy import UUID, Column, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import decrypt_value, encrypt_value
|
||||
|
||||
|
||||
class OrgMember(Base): # type: ignore
|
||||
"""Junction table for organization-member relationships with roles."""
|
||||
|
||||
__tablename__ = 'org_member'
|
||||
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), primary_key=True)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), primary_key=True)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
|
||||
_llm_api_key = Column(String, nullable=False)
|
||||
max_iterations = Column(Integer, nullable=True)
|
||||
llm_model = Column(String, nullable=True)
|
||||
_llm_api_key_for_byor = Column(String, nullable=True)
|
||||
llm_base_url = Column(String, nullable=True)
|
||||
status = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='org_members')
|
||||
user = relationship('User', back_populates='org_members')
|
||||
role = relationship('Role', back_populates='org_members')
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Handle known SQLAlchemy columns directly
|
||||
for key in list(kwargs):
|
||||
if hasattr(self.__class__, key):
|
||||
setattr(self, key, kwargs.pop(key))
|
||||
|
||||
# Handle custom property-style fields
|
||||
if 'llm_api_key' in kwargs:
|
||||
self.llm_api_key = kwargs.pop('llm_api_key')
|
||||
if 'llm_api_key_for_byor' in kwargs:
|
||||
self.llm_api_key_for_byor = kwargs.pop('llm_api_key_for_byor')
|
||||
|
||||
if kwargs:
|
||||
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
|
||||
|
||||
@property
|
||||
def llm_api_key(self) -> SecretStr:
|
||||
decrypted = decrypt_value(self._llm_api_key)
|
||||
return SecretStr(decrypted)
|
||||
|
||||
@llm_api_key.setter
|
||||
def llm_api_key(self, value: str | SecretStr):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._llm_api_key = encrypt_value(raw)
|
||||
|
||||
@property
|
||||
def llm_api_key_for_byor(self) -> SecretStr | None:
|
||||
if self._llm_api_key_for_byor:
|
||||
decrypted = decrypt_value(self._llm_api_key_for_byor)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@llm_api_key_for_byor.setter
|
||||
def llm_api_key_for_byor(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._llm_api_key_for_byor = encrypt_value(raw) if raw else None
|
||||
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Store class for managing organization-member relationships.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from storage.database import session_maker
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
class OrgMemberStore:
|
||||
"""Store for managing organization-member relationships."""
|
||||
|
||||
@staticmethod
|
||||
def add_user_to_org(
|
||||
org_id: UUID,
|
||||
user_id: UUID,
|
||||
role_id: int,
|
||||
llm_api_key: str,
|
||||
status: Optional[str] = None,
|
||||
) -> OrgMember:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
with session_maker() as session:
|
||||
org_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
llm_api_key=llm_api_key,
|
||||
status=status,
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(org_member)
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def get_org_member(org_id: UUID, user_id: int) -> Optional[OrgMember]:
|
||||
"""Get organization-user relationship."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs(user_id: int) -> list[OrgMember]:
|
||||
"""Get all organizations for a user."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.user_id == user_id).all()
|
||||
|
||||
@staticmethod
|
||||
def get_org_members(org_id: UUID) -> list[OrgMember]:
|
||||
"""Get all users in an organization."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.org_id == org_id).all()
|
||||
|
||||
@staticmethod
|
||||
def update_org_member(org_member: OrgMember) -> None:
|
||||
"""Update an organization-member relationship."""
|
||||
with session_maker() as session:
|
||||
session.merge(org_member)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def update_user_role_in_org(
|
||||
org_id: UUID, user_id: int, role_id: int, status: Optional[str] = None
|
||||
) -> Optional[OrgMember]:
|
||||
"""Update user's role in an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
org_member.role_id = role_id
|
||||
if status is not None:
|
||||
org_member.status = status
|
||||
|
||||
session.commit()
|
||||
session.refresh(org_member)
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def remove_user_from_org(org_id: UUID, user_id: int) -> bool:
|
||||
"""Remove a user from an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not org_member:
|
||||
return False
|
||||
|
||||
session.delete(org_member)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: Settings):
|
||||
kwargs = {
|
||||
normalized: getattr(settings, normalized)
|
||||
for c in OrgMember.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {
|
||||
normalized: getattr(user_settings, normalized)
|
||||
for c in OrgMember.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
Service class for managing organization operations.
|
||||
Separates business logic from route handlers.
|
||||
"""
|
||||
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
from server.constants import ORG_SETTINGS_VERSION, get_default_litellm_model
|
||||
from server.routes.org_models import (
|
||||
LiteLLMIntegrationError,
|
||||
OrgDatabaseError,
|
||||
OrgNameExistsError,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.org_store import OrgStore
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class OrgService:
|
||||
"""Service for handling organization-related operations."""
|
||||
|
||||
@staticmethod
|
||||
def validate_name_uniqueness(name: str) -> None:
|
||||
"""
|
||||
Validate that organization name is unique.
|
||||
|
||||
Args:
|
||||
name: Organization name to validate
|
||||
|
||||
Raises:
|
||||
OrgNameExistsError: If organization name already exists
|
||||
"""
|
||||
existing_org = OrgStore.get_org_by_name(name)
|
||||
if existing_org is not None:
|
||||
raise OrgNameExistsError(name)
|
||||
|
||||
@staticmethod
|
||||
async def create_litellm_integration(org_id: UUID, user_id: str) -> dict:
|
||||
"""
|
||||
Create LiteLLM team integration for the organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
user_id: User ID who will own the organization
|
||||
|
||||
Returns:
|
||||
dict: LiteLLM settings object
|
||||
|
||||
Raises:
|
||||
LiteLLMIntegrationError: If LiteLLM integration fails
|
||||
"""
|
||||
try:
|
||||
settings = await UserStore.create_default_settings(
|
||||
org_id=str(org_id), user_id=user_id, create_user=False
|
||||
)
|
||||
|
||||
if not settings:
|
||||
logger.error(
|
||||
'Failed to create LiteLLM settings',
|
||||
extra={'org_id': str(org_id), 'user_id': user_id},
|
||||
)
|
||||
raise LiteLLMIntegrationError('Failed to create LiteLLM settings')
|
||||
|
||||
logger.debug(
|
||||
'LiteLLM integration created',
|
||||
extra={'org_id': str(org_id), 'user_id': user_id},
|
||||
)
|
||||
return settings
|
||||
|
||||
except LiteLLMIntegrationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error creating LiteLLM integration',
|
||||
extra={'org_id': str(org_id), 'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
raise LiteLLMIntegrationError(f'LiteLLM integration failed: {str(e)}')
|
||||
|
||||
@staticmethod
|
||||
def create_org_entity(
|
||||
org_id: UUID,
|
||||
name: str,
|
||||
contact_name: str,
|
||||
contact_email: str,
|
||||
) -> Org:
|
||||
"""
|
||||
Create an organization entity with basic information.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
name: Organization name
|
||||
contact_name: Contact person name
|
||||
contact_email: Contact email address
|
||||
|
||||
Returns:
|
||||
Org: New organization entity (not yet persisted)
|
||||
"""
|
||||
return Org(
|
||||
id=org_id,
|
||||
name=name,
|
||||
contact_name=contact_name,
|
||||
contact_email=contact_email,
|
||||
org_version=ORG_SETTINGS_VERSION,
|
||||
default_llm_model=get_default_litellm_model(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def apply_litellm_settings_to_org(org: Org, settings: dict) -> None:
|
||||
"""
|
||||
Apply LiteLLM settings to organization entity.
|
||||
|
||||
Args:
|
||||
org: Organization entity to update
|
||||
settings: LiteLLM settings object
|
||||
"""
|
||||
org_kwargs = OrgStore.get_kwargs_from_settings(settings)
|
||||
for key, value in org_kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
@staticmethod
|
||||
def get_owner_role():
|
||||
"""
|
||||
Get the owner role from the database.
|
||||
|
||||
Returns:
|
||||
Role: The owner role object
|
||||
|
||||
Raises:
|
||||
Exception: If owner role not found
|
||||
"""
|
||||
owner_role = RoleStore.get_role_by_name('owner')
|
||||
if not owner_role:
|
||||
raise Exception('Owner role not found in database')
|
||||
return owner_role
|
||||
|
||||
@staticmethod
|
||||
def create_org_member_entity(
|
||||
org_id: UUID,
|
||||
user_id: str,
|
||||
role_id: int,
|
||||
settings: dict,
|
||||
) -> OrgMember:
|
||||
"""
|
||||
Create an organization member entity.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
role_id: Role ID
|
||||
settings: LiteLLM settings object
|
||||
|
||||
Returns:
|
||||
OrgMember: New organization member entity (not yet persisted)
|
||||
"""
|
||||
org_member_kwargs = OrgMemberStore.get_kwargs_from_settings(settings)
|
||||
return OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=parse_uuid(user_id),
|
||||
role_id=role_id,
|
||||
status='active',
|
||||
**org_member_kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_org_with_owner(
|
||||
name: str,
|
||||
contact_name: str,
|
||||
contact_email: str,
|
||||
user_id: str,
|
||||
) -> Org:
|
||||
"""
|
||||
Create a new organization with the specified user as owner.
|
||||
|
||||
This method orchestrates the complete organization creation workflow:
|
||||
1. Validates that the organization name doesn't already exist
|
||||
2. Generates a unique organization ID
|
||||
3. Creates LiteLLM team integration
|
||||
4. Creates the organization entity
|
||||
5. Applies LiteLLM settings
|
||||
6. Creates owner membership
|
||||
7. Persists everything in a transaction
|
||||
|
||||
If database persistence fails, LiteLLM resources are cleaned up (compensation).
|
||||
|
||||
Args:
|
||||
name: Organization name (must be unique)
|
||||
contact_name: Contact person name
|
||||
contact_email: Contact email address
|
||||
user_id: ID of the user who will be the owner
|
||||
|
||||
Returns:
|
||||
Org: The created organization object
|
||||
|
||||
Raises:
|
||||
OrgNameExistsError: If organization name already exists
|
||||
LiteLLMIntegrationError: If LiteLLM integration fails
|
||||
OrgDatabaseError: If database operations fail
|
||||
"""
|
||||
logger.info(
|
||||
'Starting organization creation',
|
||||
extra={'user_id': user_id, 'org_name': name},
|
||||
)
|
||||
|
||||
# Step 1: Validate name uniqueness (fails early, no cleanup needed)
|
||||
OrgService.validate_name_uniqueness(name)
|
||||
|
||||
# Step 2: Generate organization ID
|
||||
org_id = uuid4()
|
||||
|
||||
# Step 3: Create LiteLLM integration (external state created)
|
||||
settings = await OrgService.create_litellm_integration(org_id, user_id)
|
||||
|
||||
# Steps 4-7: Create entities and persist with compensation
|
||||
# If any of these fail, we need to clean up LiteLLM resources
|
||||
try:
|
||||
# Step 4: Create organization entity
|
||||
org = OrgService.create_org_entity(
|
||||
org_id=org_id,
|
||||
name=name,
|
||||
contact_name=contact_name,
|
||||
contact_email=contact_email,
|
||||
)
|
||||
|
||||
# Step 5: Apply LiteLLM settings
|
||||
OrgService.apply_litellm_settings_to_org(org, settings)
|
||||
|
||||
# Step 6: Get owner role and create member entity
|
||||
owner_role = OrgService.get_owner_role()
|
||||
org_member = OrgService.create_org_member_entity(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=owner_role.id,
|
||||
settings=settings,
|
||||
)
|
||||
|
||||
# Step 7: Persist in transaction (critical section)
|
||||
persisted_org = await OrgService._persist_with_compensation(
|
||||
org, org_member, org_id, user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Successfully created organization',
|
||||
extra={
|
||||
'org_id': str(persisted_org.id),
|
||||
'org_name': persisted_org.name,
|
||||
'user_id': user_id,
|
||||
'role': 'owner',
|
||||
},
|
||||
)
|
||||
|
||||
return persisted_org
|
||||
|
||||
except OrgDatabaseError:
|
||||
# Already handled by _persist_with_compensation, just re-raise
|
||||
raise
|
||||
except Exception as e:
|
||||
# Unexpected error in steps 4-6, need to clean up LiteLLM
|
||||
logger.error(
|
||||
'Unexpected error during organization creation, initiating cleanup',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
await OrgService._handle_failure_with_cleanup(
|
||||
org_id, user_id, e, 'Failed to create organization'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _persist_with_compensation(
|
||||
org: Org,
|
||||
org_member: OrgMember,
|
||||
org_id: UUID,
|
||||
user_id: str,
|
||||
) -> Org:
|
||||
"""
|
||||
Persist organization with compensation on failure.
|
||||
|
||||
If database persistence fails, cleans up LiteLLM resources.
|
||||
|
||||
Args:
|
||||
org: Organization entity to persist
|
||||
org_member: Organization member entity to persist
|
||||
org_id: Organization ID (for cleanup)
|
||||
user_id: User ID (for cleanup)
|
||||
|
||||
Returns:
|
||||
Org: The persisted organization object
|
||||
|
||||
Raises:
|
||||
OrgDatabaseError: If database operations fail
|
||||
"""
|
||||
try:
|
||||
persisted_org = OrgStore.persist_org_with_owner(org, org_member)
|
||||
return persisted_org
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Database persistence failed, initiating LiteLLM cleanup',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
await OrgService._handle_failure_with_cleanup(
|
||||
org_id, user_id, e, 'Failed to create organization'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _handle_failure_with_cleanup(
|
||||
org_id: UUID,
|
||||
user_id: str,
|
||||
original_error: Exception,
|
||||
error_message: str,
|
||||
) -> None:
|
||||
"""
|
||||
Handle failure by cleaning up LiteLLM resources and raising appropriate error.
|
||||
|
||||
This method performs compensating transaction and raises OrgDatabaseError.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
user_id: User ID
|
||||
original_error: The original exception that caused the failure
|
||||
error_message: Base error message for the exception
|
||||
|
||||
Raises:
|
||||
OrgDatabaseError: Always raises with details about the failure
|
||||
"""
|
||||
cleanup_error = await OrgService._cleanup_litellm_resources(org_id, user_id)
|
||||
|
||||
if cleanup_error:
|
||||
logger.error(
|
||||
'Both operation and cleanup failed',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'original_error': str(original_error),
|
||||
'cleanup_error': str(cleanup_error),
|
||||
},
|
||||
)
|
||||
raise OrgDatabaseError(
|
||||
f'{error_message}: {str(original_error)}. '
|
||||
f'Cleanup also failed: {str(cleanup_error)}'
|
||||
)
|
||||
|
||||
raise OrgDatabaseError(f'{error_message}: {str(original_error)}')
|
||||
|
||||
@staticmethod
|
||||
async def _cleanup_litellm_resources(
|
||||
org_id: UUID, user_id: str
|
||||
) -> Exception | None:
|
||||
"""
|
||||
Compensating transaction: Clean up LiteLLM resources.
|
||||
|
||||
Deletes the team which should cascade to remove keys and memberships.
|
||||
This is a best-effort operation - errors are logged but not raised.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Exception | None: Exception if cleanup failed, None if successful
|
||||
"""
|
||||
try:
|
||||
await LiteLlmManager.delete_team(str(org_id))
|
||||
|
||||
logger.info(
|
||||
'Successfully cleaned up LiteLLM team',
|
||||
extra={'org_id': str(org_id), 'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to cleanup LiteLLM team (resources may be orphaned)',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return e
|
||||
|
||||
@staticmethod
|
||||
async def get_org_credits(user_id: str, org_id: UUID) -> float | None:
|
||||
"""
|
||||
Get organization credits from LiteLLM team.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
org_id: Organization ID
|
||||
|
||||
Returns:
|
||||
float | None: Credits (max_budget - spend) or None if LiteLLM not configured
|
||||
"""
|
||||
try:
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
user_id, str(org_id)
|
||||
)
|
||||
if not user_team_info:
|
||||
logger.warning(
|
||||
'No team info available from LiteLLM',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
return None
|
||||
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
)
|
||||
spend = user_team_info.get('spend', 0)
|
||||
credits = max(max_budget - spend, 0)
|
||||
|
||||
logger.debug(
|
||||
'Retrieved organization credits',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'credits': credits,
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
|
||||
return credits
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'Failed to retrieve organization credits',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Store class for managing organizations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from server.constants import (
|
||||
LITE_LLM_API_URL,
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
class OrgStore:
|
||||
"""Store for managing organizations."""
|
||||
|
||||
@staticmethod
|
||||
def create_org(
|
||||
kwargs: dict,
|
||||
) -> Org:
|
||||
"""Create a new organization."""
|
||||
with session_maker() as session:
|
||||
org = Org(**kwargs)
|
||||
org.org_version = ORG_SETTINGS_VERSION
|
||||
org.default_llm_model = get_default_litellm_model()
|
||||
session.add(org)
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_org_by_id(org_id: UUID) -> Org | None:
|
||||
"""Get organization by ID."""
|
||||
org = None
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
return OrgStore._validate_org_version(org)
|
||||
|
||||
@staticmethod
|
||||
def get_current_org_from_keycloak_user_id(keycloak_user_id: str) -> Org | None:
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == UUID(keycloak_user_id))
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
logger.warning(f'User not found for ID {keycloak_user_id}')
|
||||
return None
|
||||
org_id = user.current_org_id
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.warning(
|
||||
f'Org not found for ID {org_id} as the current org for user {keycloak_user_id}'
|
||||
)
|
||||
return None
|
||||
return OrgStore._validate_org_version(org)
|
||||
|
||||
@staticmethod
|
||||
def get_org_by_name(name: str) -> Org | None:
|
||||
"""Get organization by name."""
|
||||
org = None
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.name == name).first()
|
||||
return OrgStore._validate_org_version(org)
|
||||
|
||||
@staticmethod
|
||||
def _validate_org_version(org: Org) -> Org | None:
|
||||
"""Check if we need to update org version."""
|
||||
if org and org.org_version < ORG_SETTINGS_VERSION:
|
||||
org = OrgStore.update_org(
|
||||
org.id,
|
||||
{
|
||||
'org_version': ORG_SETTINGS_VERSION,
|
||||
'default_llm_model': get_default_litellm_model(),
|
||||
'llm_base_url': LITE_LLM_API_URL,
|
||||
},
|
||||
)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def list_orgs() -> list[Org]:
|
||||
"""List all organizations."""
|
||||
with session_maker() as session:
|
||||
orgs = session.query(Org).all()
|
||||
return orgs
|
||||
|
||||
@staticmethod
|
||||
def update_org(
|
||||
org_id: UUID,
|
||||
kwargs: dict,
|
||||
) -> Optional[Org]:
|
||||
"""Update organization details."""
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
return None
|
||||
|
||||
if 'id' in kwargs:
|
||||
kwargs.pop('id')
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: Settings):
|
||||
kwargs = {}
|
||||
|
||||
for c in Org.__table__.columns:
|
||||
# Normalize for lookup
|
||||
normalized = (
|
||||
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
|
||||
)
|
||||
|
||||
if not hasattr(settings, normalized):
|
||||
continue
|
||||
|
||||
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
|
||||
key = c.name
|
||||
if key.startswith('_'):
|
||||
key = key[1:] # remove only the very first leading underscore
|
||||
|
||||
kwargs[key] = getattr(settings, normalized)
|
||||
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {}
|
||||
|
||||
for c in Org.__table__.columns:
|
||||
# Normalize for lookup
|
||||
normalized = (
|
||||
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
|
||||
)
|
||||
|
||||
if not hasattr(user_settings, normalized):
|
||||
continue
|
||||
|
||||
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
|
||||
key = c.name
|
||||
if key.startswith('_'):
|
||||
key = key[1:] # remove only the very first leading underscore
|
||||
|
||||
kwargs[key] = getattr(user_settings, normalized)
|
||||
|
||||
kwargs['org_version'] = user_settings.user_version
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def persist_org_with_owner(
|
||||
org: Org,
|
||||
org_member: OrgMember,
|
||||
) -> Org:
|
||||
"""
|
||||
Persist organization and owner membership in a single transaction.
|
||||
|
||||
Args:
|
||||
org: Organization entity to persist
|
||||
org_member: Organization member entity to persist
|
||||
|
||||
Returns:
|
||||
Org: The persisted organization object
|
||||
|
||||
Raises:
|
||||
Exception: If database operations fail
|
||||
"""
|
||||
with session_maker() as session:
|
||||
session.add(org)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
SQLAlchemy model for Role.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class Role(Base): # type: ignore
|
||||
"""Role model for user permissions."""
|
||||
|
||||
__tablename__ = 'role'
|
||||
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
rank = Column(Integer, nullable=False)
|
||||
|
||||
# Relationships
|
||||
users = relationship('User', back_populates='role')
|
||||
org_members = relationship('OrgMember', back_populates='role')
|
||||
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Store class for managing roles.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from storage.database import session_maker
|
||||
from storage.role import Role
|
||||
|
||||
|
||||
class RoleStore:
|
||||
"""Store for managing roles."""
|
||||
|
||||
@staticmethod
|
||||
def create_role(name: str, rank: int) -> Role:
|
||||
"""Create a new role."""
|
||||
with session_maker() as session:
|
||||
role = Role(name=name, rank=rank)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
session.refresh(role)
|
||||
return role
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_id(role_id: int) -> Optional[Role]:
|
||||
"""Get role by ID."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.id == role_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_name(name: str) -> Optional[Role]:
|
||||
"""Get role by name."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.name == name).first()
|
||||
|
||||
@staticmethod
|
||||
def list_roles() -> List[Role]:
|
||||
"""List all roles."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).order_by(Role.rank).all()
|
||||
@@ -4,10 +4,13 @@ import dataclasses
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.integrations.provider import ProviderType
|
||||
@@ -29,20 +32,37 @@ logger = logging.getLogger(__name__)
|
||||
class SaasConversationStore(ConversationStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
org_id: UUID | None = None # will be fetched automatically
|
||||
|
||||
def __init__(self, user_id: str, session_maker: sessionmaker):
|
||||
self.user_id = user_id
|
||||
self.session_maker = session_maker
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
self.org_id = user.current_org_id if user else None
|
||||
|
||||
def _select_by_id(self, session, conversation_id: str):
|
||||
return (
|
||||
# Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org
|
||||
query = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.user_id == self.user_id)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(StoredConversationMetadataSaas.user_id == UUID(self.user_id))
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.filter(StoredConversationMetadata.conversation_version == 'V0')
|
||||
)
|
||||
|
||||
if self.org_id is not None:
|
||||
query = query.filter(StoredConversationMetadataSaas.org_id == self.org_id)
|
||||
|
||||
return query
|
||||
|
||||
def _to_external_model(self, conversation_metadata: StoredConversationMetadata):
|
||||
kwargs = {
|
||||
c.name: getattr(conversation_metadata, c.name)
|
||||
for c in StoredConversationMetadata.__table__.columns
|
||||
if c.name != 'github_user_id' # Skip github_user_id field
|
||||
}
|
||||
# TODO: I'm not sure why the timezone is not set on the dates coming back out of the db
|
||||
kwargs['created_at'] = kwargs['created_at'].replace(tzinfo=UTC)
|
||||
@@ -53,6 +73,8 @@ class SaasConversationStore(ConversationStore):
|
||||
# Convert string to ProviderType enum
|
||||
kwargs['git_provider'] = ProviderType(kwargs['git_provider'])
|
||||
|
||||
kwargs['user_id'] = self.user_id
|
||||
|
||||
# Remove V1 attributes
|
||||
kwargs.pop('max_budget_per_task', None)
|
||||
kwargs.pop('cache_read_tokens', None)
|
||||
@@ -67,7 +89,10 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
async def save_metadata(self, metadata: ConversationMetadata):
|
||||
kwargs = dataclasses.asdict(metadata)
|
||||
kwargs['user_id'] = self.user_id
|
||||
|
||||
# Remove user_id and org_id from kwargs since they're no longer in StoredConversationMetadata
|
||||
kwargs.pop('user_id', None)
|
||||
kwargs.pop('org_id', None)
|
||||
|
||||
# Convert ProviderType enum to string for storage
|
||||
if kwargs.get('git_provider') is not None:
|
||||
@@ -81,7 +106,41 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
def _save_metadata():
|
||||
with self.session_maker() as session:
|
||||
# Save the main conversation metadata
|
||||
session.merge(stored_metadata)
|
||||
|
||||
# Create or update the SaaS metadata record
|
||||
saas_metadata = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== stored_metadata.conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not saas_metadata:
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=stored_metadata.conversation_id,
|
||||
user_id=UUID(self.user_id),
|
||||
org_id=self.org_id,
|
||||
)
|
||||
session.add(saas_metadata)
|
||||
else:
|
||||
# Validate
|
||||
expected_user_id = UUID(self.user_id)
|
||||
expected_org_id = self.org_id
|
||||
|
||||
if saas_metadata.user_id != expected_user_id:
|
||||
raise ValueError(
|
||||
f'Existing user_id ({saas_metadata.user_id}) does not match expected value ({expected_user_id}).'
|
||||
)
|
||||
|
||||
if expected_org_id and saas_metadata.org_id != expected_org_id:
|
||||
raise ValueError(
|
||||
f'Existing org_id ({saas_metadata.org_id}) does not match expected value ({expected_org_id}).'
|
||||
)
|
||||
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_save_metadata)
|
||||
@@ -101,8 +160,29 @@ class SaasConversationStore(ConversationStore):
|
||||
async def delete_metadata(self, conversation_id: str) -> None:
|
||||
def _delete_metadata():
|
||||
with self.session_maker() as session:
|
||||
self._select_by_id(session, conversation_id).delete()
|
||||
session.commit()
|
||||
saas_record = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == UUID(self.user_id),
|
||||
StoredConversationMetadataSaas.org_id == self.org_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if saas_record:
|
||||
# Delete both records, but only if the SaaS one exists
|
||||
session.query(StoredConversationMetadata).filter(
|
||||
StoredConversationMetadata.conversation_id == conversation_id,
|
||||
).delete()
|
||||
|
||||
session.delete(saas_record)
|
||||
|
||||
session.commit()
|
||||
else:
|
||||
# No SaaS record found → skip deleting main metadata
|
||||
session.rollback()
|
||||
|
||||
await call_sync_from_async(_delete_metadata)
|
||||
|
||||
@@ -125,7 +205,15 @@ class SaasConversationStore(ConversationStore):
|
||||
with self.session_maker() as session:
|
||||
conversations = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.user_id == self.user_id)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.user_id == UUID(self.user_id)
|
||||
)
|
||||
.filter(StoredConversationMetadataSaas.org_id == self.org_id)
|
||||
.filter(StoredConversationMetadata.conversation_version == 'V0')
|
||||
.order_by(StoredConversationMetadata.created_at.desc())
|
||||
.offset(offset)
|
||||
|
||||
@@ -8,6 +8,7 @@ from cryptography.fernet import Fernet
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -24,14 +25,17 @@ class SaasSecretsStore(SecretsStore):
|
||||
async def load(self) -> Secrets | None:
|
||||
if not self.user_id:
|
||||
return None
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Fetch all secrets for the given user ID
|
||||
settings = (
|
||||
session.query(StoredCustomSecrets)
|
||||
.filter(StoredCustomSecrets.keycloak_user_id == self.user_id)
|
||||
.all()
|
||||
query = session.query(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
)
|
||||
if org_id is not None:
|
||||
query = query.filter(StoredCustomSecrets.org_id == org_id)
|
||||
settings = query.all()
|
||||
|
||||
if not settings:
|
||||
return Secrets()
|
||||
@@ -48,6 +52,8 @@ class SaasSecretsStore(SecretsStore):
|
||||
return Secrets(custom_secrets=kwargs) # type: ignore[arg-type]
|
||||
|
||||
async def store(self, item: Secrets):
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
# Incoming secrets are always the most updated ones
|
||||
# Delete all existing records and override with incoming ones
|
||||
@@ -76,6 +82,7 @@ class SaasSecretsStore(SecretsStore):
|
||||
for secret_name, secret_value, description in secret_tuples:
|
||||
new_secret = StoredCustomSecrets(
|
||||
keycloak_user_id=self.user_id,
|
||||
org_id=org_id,
|
||||
secret_name=secret_name,
|
||||
secret_value=secret_value,
|
||||
description=description,
|
||||
|
||||
@@ -1,46 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import binascii
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from base64 import b64decode, b64encode
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
from cryptography.fernet import Fernet
|
||||
from integrations import stripe_service
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
CURRENT_USER_SETTINGS_VERSION,
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
REQUIRE_PAYMENT,
|
||||
USER_SETTINGS_VERSION_TO_MODEL,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import joinedload, sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_store import OrgStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
# The max possible time to wait for another process to finish creating a user before retrying
|
||||
_REDIS_CREATE_TIMEOUT_SECONDS = 30
|
||||
# The delay to wait for another process to finish creating a user before trying to load again
|
||||
_RETRY_LOAD_DELAY_SECONDS = 2
|
||||
# Redis key prefix for user creation locks
|
||||
_REDIS_USER_CREATION_KEY_PREFIX = 'create_user:'
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -48,8 +30,9 @@ class SaasSettingsStore(SettingsStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
ENCRYPT_VALUES = ['llm_api_key', 'llm_api_key_for_byor', 'search_api_key']
|
||||
|
||||
def get_user_settings_by_keycloak_id(
|
||||
def _get_user_settings_by_keycloak_id(
|
||||
self, keycloak_user_id: str, session=None
|
||||
) -> UserSettings | None:
|
||||
"""
|
||||
@@ -85,354 +68,105 @@ class SaasSettingsStore(SettingsStore):
|
||||
return _get_settings()
|
||||
|
||||
async def load(self) -> Settings | None:
|
||||
if not self.user_id:
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
if not user:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
return None
|
||||
with self.session_maker() as session:
|
||||
settings = self.get_user_settings_by_keycloak_id(self.user_id, session)
|
||||
|
||||
if not settings or settings.user_version != CURRENT_USER_SETTINGS_VERSION:
|
||||
logger.info(
|
||||
'saas_settings_store:load:triggering_migration',
|
||||
extra={'user_id': self.user_id},
|
||||
org_id = user.current_org_id
|
||||
org_member: OrgMember = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
)
|
||||
return None
|
||||
kwargs = {
|
||||
**{
|
||||
normalized: getattr(org, c.name)
|
||||
for c in Org.__table__.columns
|
||||
if (
|
||||
normalized := c.name.removeprefix('_default_')
|
||||
.removeprefix('default_')
|
||||
.lstrip('_')
|
||||
)
|
||||
return await self.create_default_settings(settings)
|
||||
kwargs = {
|
||||
c.name: getattr(settings, c.name)
|
||||
for c in UserSettings.__table__.columns
|
||||
if c.name in Settings.model_fields
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
in Settings.model_fields
|
||||
},
|
||||
**{
|
||||
normalized: getattr(user, c.name)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) in Settings.model_fields
|
||||
},
|
||||
}
|
||||
kwargs['llm_api_key'] = org_member.llm_api_key
|
||||
if org_member.max_iterations:
|
||||
kwargs['max_iterations'] = org_member.max_iterations
|
||||
if org_member.llm_model:
|
||||
kwargs['llm_model'] = org_member.llm_model
|
||||
if org_member.llm_api_key_for_byor:
|
||||
kwargs['llm_api_key_for_byor'] = org_member.llm_api_key_for_byor
|
||||
if org_member.llm_base_url:
|
||||
kwargs['llm_base_url'] = org_member.llm_base_url
|
||||
if org.v1_enabled is None:
|
||||
kwargs['v1_enabled'] = True
|
||||
|
||||
return settings
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
|
||||
async def store(self, item: Settings):
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if item and self._is_openhands_provider(item):
|
||||
await self._ensure_openhands_api_key(item)
|
||||
|
||||
with self.session_maker() as session:
|
||||
existing = None
|
||||
kwargs = {}
|
||||
if item:
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
self._encrypt_kwargs(kwargs)
|
||||
# First check if we have an existing entry in the new table
|
||||
existing = self.get_user_settings_by_keycloak_id(self.user_id, session)
|
||||
|
||||
kwargs = {
|
||||
key: value
|
||||
for key, value in kwargs.items()
|
||||
if key in UserSettings.__table__.columns
|
||||
}
|
||||
if existing:
|
||||
# Update existing entry
|
||||
for key, value in kwargs.items():
|
||||
setattr(existing, key, value)
|
||||
existing.user_version = CURRENT_USER_SETTINGS_VERSION
|
||||
session.merge(existing)
|
||||
else:
|
||||
kwargs['keycloak_user_id'] = self.user_id
|
||||
kwargs['user_version'] = CURRENT_USER_SETTINGS_VERSION
|
||||
kwargs.pop('secrets_store', None) # Don't save secrets_store to db
|
||||
settings = UserSettings(**kwargs)
|
||||
session.add(settings)
|
||||
session.commit()
|
||||
|
||||
def _get_redis_client(self):
|
||||
"""Get the Redis client from the Socket.IO manager."""
|
||||
from openhands.server.shared import sio
|
||||
|
||||
return getattr(sio.manager, 'redis', None)
|
||||
|
||||
async def _acquire_user_creation_lock(self) -> bool:
|
||||
"""Attempt to acquire a distributed lock for user creation.
|
||||
|
||||
Returns True if the lock was acquired or if Redis is unavailable (fallback to no locking).
|
||||
Returns False if another process holds the lock.
|
||||
"""
|
||||
redis_client = self._get_redis_client()
|
||||
if redis_client is None:
|
||||
logger.warning(
|
||||
'saas_settings_store:_acquire_user_creation_lock:no_redis_client',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return True # Proceed without locking if Redis is unavailable
|
||||
|
||||
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{self.user_id}'
|
||||
lock_acquired = await redis_client.set(
|
||||
user_key, 1, nx=True, ex=_REDIS_CREATE_TIMEOUT_SECONDS
|
||||
)
|
||||
return bool(lock_acquired)
|
||||
|
||||
async def create_default_settings(self, user_settings: UserSettings | None):
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:start',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
# You must log in before you get default settings
|
||||
if not self.user_id:
|
||||
return None
|
||||
|
||||
# Prevent duplicate settings creation using distributed lock
|
||||
if not await self._acquire_user_creation_lock():
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
|
||||
return await self.load()
|
||||
|
||||
# Only users that have specified a payment method get default settings
|
||||
if REQUIRE_PAYMENT and not await stripe_service.has_payment_method(
|
||||
self.user_id
|
||||
):
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:no_payment',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
settings: Settings | None = None
|
||||
if user_settings is None:
|
||||
settings = Settings(
|
||||
language='en',
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
elif isinstance(user_settings, UserSettings):
|
||||
# Convert UserSettings (SQLAlchemy model) to Settings (Pydantic model)
|
||||
kwargs = {
|
||||
c.name: getattr(user_settings, c.name)
|
||||
for c in UserSettings.__table__.columns
|
||||
if c.name in Settings.model_fields
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
|
||||
if settings:
|
||||
settings = await self.update_settings_with_litellm_default(settings)
|
||||
if settings is None:
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:litellm_update_failed',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
await self.store(settings)
|
||||
return settings
|
||||
|
||||
async def load_legacy_file_store_settings(self, github_user_id: str):
|
||||
if not github_user_id:
|
||||
return None
|
||||
|
||||
file_store = get_file_store(self.config.file_store, self.config.file_store_path)
|
||||
path = f'users/github/{github_user_id}/settings.json'
|
||||
|
||||
try:
|
||||
json_str = await call_sync_from_async(file_store.read, path)
|
||||
logger.info(
|
||||
'saas_settings_store:load_legacy_file_store_settings:found',
|
||||
extra={'github_user_id': github_user_id},
|
||||
)
|
||||
kwargs = json.loads(json_str)
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'saas_settings_store:load_legacy_file_store_settings:error',
|
||||
extra={'github_user_id': github_user_id, 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
def _has_custom_settings(
|
||||
self, settings: Settings, old_user_version: int | None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has custom LLM settings that should be preserved.
|
||||
Returns True if user customized either model or base_url.
|
||||
|
||||
Args:
|
||||
settings: The user's current settings
|
||||
old_user_version: The user's old settings version, if any
|
||||
|
||||
Returns:
|
||||
True if user has custom settings, False if using old defaults
|
||||
"""
|
||||
# Normalize values
|
||||
user_model = (
|
||||
settings.llm_model.strip()
|
||||
if settings.llm_model and settings.llm_model.strip()
|
||||
else None
|
||||
)
|
||||
user_base_url = (
|
||||
settings.llm_base_url.strip()
|
||||
if settings.llm_base_url and settings.llm_base_url.strip()
|
||||
else None
|
||||
)
|
||||
|
||||
# Custom base_url = definitely custom settings (BYOK)
|
||||
if user_base_url and user_base_url != LITE_LLM_API_URL:
|
||||
return True
|
||||
|
||||
# No model set = using defaults
|
||||
if not user_model:
|
||||
return False
|
||||
|
||||
# Check if model matches old version's default
|
||||
if (
|
||||
old_user_version
|
||||
and old_user_version < CURRENT_USER_SETTINGS_VERSION
|
||||
and old_user_version in USER_SETTINGS_VERSION_TO_MODEL
|
||||
):
|
||||
old_default_base = USER_SETTINGS_VERSION_TO_MODEL[old_user_version]
|
||||
user_model_base = user_model.split('/')[-1]
|
||||
if user_model_base == old_default_base:
|
||||
return False # Matches old default
|
||||
|
||||
return True # Custom model
|
||||
|
||||
async def update_settings_with_litellm_default(
|
||||
self, settings: Settings
|
||||
) -> Settings | None:
|
||||
logger.info(
|
||||
'saas_settings_store:update_settings_with_litellm_default:start',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
key = LITE_LLM_API_KEY
|
||||
|
||||
# Check if user has custom settings
|
||||
has_custom = self._has_custom_settings(settings, settings.user_version)
|
||||
|
||||
# Determine model to use (needed before LiteLLM user creation)
|
||||
llm_model_to_use = (
|
||||
settings.llm_model
|
||||
if has_custom and settings.llm_model
|
||||
else get_default_litellm_model()
|
||||
)
|
||||
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
token_manager = TokenManager()
|
||||
keycloak_user_info = (
|
||||
await token_manager.get_user_info_from_user_id(self.user_id) or {}
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
# Get the previous max budget to prevent accidental loss.
|
||||
#
|
||||
# LiteLLM v1.80+ returns 404 for non-existent users (previously returned empty user_info)
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={self.user_id}'
|
||||
)
|
||||
user_info: dict
|
||||
if response.status_code == 404:
|
||||
# New user - doesn't exist in LiteLLM yet (v1.80+ behavior)
|
||||
user_info = {}
|
||||
else:
|
||||
# For any other status, use standard error handling
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
user_info = response_json.get('user_info') or {}
|
||||
logger.info(
|
||||
f'creating_litellm_user: {self.user_id}; prev_max_budget: {user_info.get("max_budget")}; prev_metadata: {user_info.get("metadata")}'
|
||||
)
|
||||
max_budget = user_info.get('max_budget') or DEFAULT_INITIAL_BUDGET
|
||||
spend = user_info.get('spend') or 0
|
||||
if not item:
|
||||
return None
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(self.user_id))
|
||||
).first()
|
||||
|
||||
if not user:
|
||||
# Check if we need to migrate from user_settings
|
||||
user_settings = None
|
||||
with session_maker() as session:
|
||||
user_settings = self.get_user_settings_by_keycloak_id(
|
||||
user_settings = self._get_user_settings_by_keycloak_id(
|
||||
self.user_id, session
|
||||
)
|
||||
# In upgrade to V4, we no longer use billing margin, but instead apply this directly
|
||||
# in litellm. The default billing marign was 2 before this (hence the magic numbers below)
|
||||
if (
|
||||
user_settings
|
||||
and user_settings.user_version < 4
|
||||
and user_settings.billing_margin
|
||||
and user_settings.billing_margin != 1.0
|
||||
):
|
||||
billing_margin = user_settings.billing_margin
|
||||
logger.info(
|
||||
'user_settings_v4_budget_upgrade',
|
||||
extra={
|
||||
'max_budget': max_budget,
|
||||
'billing_margin': billing_margin,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
max_budget *= billing_margin
|
||||
spend *= billing_margin
|
||||
user_settings.billing_margin = 1.0
|
||||
session.commit()
|
||||
|
||||
email = keycloak_user_info.get('email')
|
||||
|
||||
# We explicitly delete here to guard against odd inherited settings on upgrade.
|
||||
# We don't care if this fails with a 404
|
||||
await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [self.user_id]}
|
||||
)
|
||||
|
||||
# Create the new litellm user
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, email, max_budget, spend, llm_model_to_use
|
||||
)
|
||||
if not response.is_success:
|
||||
logger.warning(
|
||||
'duplicate_user_email',
|
||||
extra={'user_id': self.user_id, 'email': email},
|
||||
)
|
||||
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, None, max_budget, spend, llm_model_to_use
|
||||
)
|
||||
|
||||
# User failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [self.user_id],
|
||||
'email': email,
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
if user_settings:
|
||||
user = await UserStore.migrate_user(self.user_id, user_settings)
|
||||
else:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
return None
|
||||
|
||||
response_json = response.json()
|
||||
key = response_json['key']
|
||||
|
||||
logger.info(
|
||||
'saas_settings_store:update_settings_with_litellm_default:user_created',
|
||||
extra={'user_id': self.user_id},
|
||||
org_id = user.current_org_id
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if self._is_openhands_provider(item):
|
||||
await self._ensure_openhands_api_key(item, str(org_id))
|
||||
org_member = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
)
|
||||
return None
|
||||
|
||||
if has_custom:
|
||||
settings.llm_model = settings.llm_model or get_default_litellm_model()
|
||||
settings.llm_base_url = settings.llm_base_url or LITE_LLM_API_URL
|
||||
settings.llm_api_key = settings.llm_api_key or SecretStr(key)
|
||||
else:
|
||||
settings.llm_model = get_default_litellm_model()
|
||||
settings.llm_base_url = LITE_LLM_API_URL
|
||||
settings.llm_api_key = SecretStr(key)
|
||||
for model in (user, org, org_member):
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(model, key):
|
||||
setattr(model, key, value)
|
||||
|
||||
settings.agent = 'CodeActAgent'
|
||||
|
||||
return settings
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(
|
||||
@@ -443,6 +177,9 @@ class SaasSettingsStore(SettingsStore):
|
||||
logger.debug(f'saas_settings_store.get_instance::{user_id}')
|
||||
return SaasSettingsStore(user_id, session_maker, config)
|
||||
|
||||
def _should_encrypt(self, key):
|
||||
return key in self.ENCRYPT_VALUES
|
||||
|
||||
def _decrypt_kwargs(self, kwargs: dict):
|
||||
fernet = self._fernet()
|
||||
for key, value in kwargs.items():
|
||||
@@ -486,21 +223,24 @@ class SaasSettingsStore(SettingsStore):
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
return Fernet(fernet_key)
|
||||
|
||||
def _should_encrypt(self, key: str) -> bool:
|
||||
return key in ('llm_api_key', 'llm_api_key_for_byor', 'search_api_key')
|
||||
|
||||
def _is_openhands_provider(self, item: Settings) -> bool:
|
||||
"""Check if the settings use the OpenHands provider."""
|
||||
return bool(item.llm_model and item.llm_model.startswith('openhands/'))
|
||||
|
||||
async def _ensure_openhands_api_key(self, item: Settings) -> None:
|
||||
async def _ensure_openhands_api_key(self, item: Settings, org_id: str) -> None:
|
||||
"""Generate and set the OpenHands API key for the given settings.
|
||||
|
||||
First checks if an existing key with the OpenHands alias exists,
|
||||
and reuses it if found. Otherwise, generates a new key.
|
||||
"""
|
||||
# Generate new key if none exists
|
||||
generated_key = await self._generate_openhands_key()
|
||||
generated_key = await LiteLlmManager.generate_key(
|
||||
self.user_id,
|
||||
org_id,
|
||||
None,
|
||||
{'type': 'openhands'},
|
||||
)
|
||||
|
||||
if generated_key:
|
||||
item.llm_api_key = SecretStr(generated_key)
|
||||
logger.info(
|
||||
@@ -512,83 +252,3 @@ class SaasSettingsStore(SettingsStore):
|
||||
'saas_settings_store:store:failed_to_generate_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
|
||||
async def _create_user_in_lite_llm(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
email: str | None,
|
||||
max_budget: int,
|
||||
spend: int,
|
||||
llm_model: str,
|
||||
):
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': email,
|
||||
'models': [],
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
'user_id': str(self.user_id),
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': True,
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': CURRENT_USER_SETTINGS_VERSION,
|
||||
'model': llm_model,
|
||||
},
|
||||
'key_alias': f'OpenHands Cloud - user {self.user_id}',
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
||||
async def _generate_openhands_key(self) -> str | None:
|
||||
"""Generate a new OpenHands provider key for a user."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'saas_settings_store:_generate_openhands_key:litellm_config_not_found',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json={
|
||||
'user_id': self.user_id,
|
||||
'metadata': {'type': 'openhands'},
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json.get('key')
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'saas_settings_store:_generate_openhands_key:success',
|
||||
extra={
|
||||
'user_id': self.user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': (
|
||||
key[:10] + '...' if key and len(key) > 10 else key
|
||||
),
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'saas_settings_store:_generate_openhands_key:no_key_in_response',
|
||||
extra={'user_id': self.user_id, 'response_json': response_json},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'saas_settings_store:_generate_openhands_key:error',
|
||||
extra={'user_id': self.user_id, 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Boolean, Column, Identity, Integer, String
|
||||
from sqlalchemy import Boolean, Column, ForeignKey, Identity, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -8,5 +10,9 @@ class SlackConversation(Base): # type: ignore
|
||||
conversation_id = Column(String, nullable=False, index=True)
|
||||
channel_id = Column(String, nullable=False)
|
||||
keycloak_user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
parent_id = Column(String, nullable=True, index=True)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='slack_conversations')
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, DateTime, Identity, Integer, String, text
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Identity, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -6,6 +8,7 @@ class SlackUser(Base): # type: ignore
|
||||
__tablename__ = 'slack_users'
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
keycloak_user_id = Column(String, nullable=False, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
slack_user_id = Column(String, nullable=False, index=True)
|
||||
slack_display_name = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
@@ -13,3 +16,6 @@ class SlackUser(Base): # type: ignore
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='slack_users')
|
||||
|
||||
@@ -4,5 +4,4 @@ from openhands.app_server.app_conversation.sql_app_conversation_info_service imp
|
||||
|
||||
StoredConversationMetadata = _StoredConversationMetadata
|
||||
|
||||
|
||||
__all__ = ['StoredConversationMetadata']
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
SQLAlchemy model for ConversationMetadataSaas.
|
||||
|
||||
This model stores the SaaS-specific metadata for conversations,
|
||||
containing only the conversation_id, user_id, and org_id.
|
||||
"""
|
||||
|
||||
from sqlalchemy import UUID as SQL_UUID
|
||||
from sqlalchemy import Column, ForeignKey, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class StoredConversationMetadataSaas(Base): # type: ignore
|
||||
"""SaaS conversation metadata model containing user and org associations."""
|
||||
|
||||
__tablename__ = 'conversation_metadata_saas'
|
||||
|
||||
conversation_id = Column(String, primary_key=True)
|
||||
user_id = Column(SQL_UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
|
||||
org_id = Column(SQL_UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
|
||||
|
||||
# Relationships
|
||||
user = relationship('User', back_populates='stored_conversation_metadata_saas')
|
||||
org = relationship('Org', back_populates='stored_conversation_metadata_saas')
|
||||
|
||||
|
||||
__all__ = ['StoredConversationMetadataSaas']
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from sqlalchemy import Column, ForeignKey, Identity, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -6,6 +8,10 @@ class StoredCustomSecrets(Base): # type: ignore
|
||||
__tablename__ = 'custom_secrets'
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
keycloak_user_id = Column(String, nullable=True, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
secret_name = Column(String, nullable=False)
|
||||
secret_value = Column(String, nullable=False)
|
||||
description = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='user_secrets')
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, DateTime, Integer, String, text
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -13,6 +15,7 @@ class StripeCustomer(Base): # type: ignore
|
||||
__tablename__ = 'stripe_customers'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
keycloak_user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
stripe_customer_id = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
|
||||
@@ -23,3 +26,6 @@ class StripeCustomer(Base): # type: ignore
|
||||
onupdate=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='stripe_customers')
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
SQLAlchemy model for User.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import (
|
||||
UUID,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class User(Base): # type: ignore
|
||||
"""User model with organizational relationships."""
|
||||
|
||||
__tablename__ = 'user'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
current_org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=True)
|
||||
accepted_tos = Column(DateTime, nullable=True)
|
||||
enable_sound_notifications = Column(Boolean, nullable=True)
|
||||
language = Column(String, nullable=True)
|
||||
user_consents_to_analytics = Column(Boolean, nullable=True)
|
||||
email = Column(String, nullable=True)
|
||||
email_verified = Column(Boolean, nullable=True)
|
||||
|
||||
# Relationships
|
||||
role = relationship('Role', back_populates='users')
|
||||
org_members = relationship('OrgMember', back_populates='user')
|
||||
current_org = relationship('Org', back_populates='current_users')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='user'
|
||||
)
|
||||
@@ -39,3 +39,6 @@ class UserSettings(Base): # type: ignore
|
||||
git_user_name = Column(String, nullable=True)
|
||||
git_user_email = Column(String, nullable=True)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
already_migrated = Column(
|
||||
Boolean, nullable=True, default=False
|
||||
) # False = not migrated, True = migrated
|
||||
|
||||
@@ -0,0 +1,528 @@
|
||||
"""
|
||||
Store class for managing users.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
LITE_LLM_API_URL,
|
||||
ORG_SETTINGS_VERSION,
|
||||
PERSONAL_WORKSPACE_VERSION_TO_MODEL,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.encrypt_utils import decrypt_legacy_model
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
|
||||
|
||||
# The max possible time to wait for another process to finish creating a user before retrying
|
||||
_REDIS_CREATE_TIMEOUT_SECONDS = 30
|
||||
# The delay to wait for another process to finish creating a user before trying to load again
|
||||
_RETRY_LOAD_DELAY_SECONDS = 2
|
||||
# Redis key prefix for user creation locks
|
||||
_REDIS_USER_CREATION_KEY_PREFIX = 'create_user:'
|
||||
|
||||
|
||||
class UserStore:
|
||||
"""Store for managing users."""
|
||||
|
||||
@staticmethod
|
||||
async def create_user(
|
||||
user_id: str,
|
||||
user_info: dict,
|
||||
role_id: Optional[int] = None,
|
||||
) -> User | None:
|
||||
"""Create a new user."""
|
||||
with session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_name=user_info['preferred_username'],
|
||||
contact_email=user_info['email'],
|
||||
v1_enabled=True,
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
settings = await UserStore.create_default_settings(
|
||||
org_id=str(org.id), user_id=user_id
|
||||
)
|
||||
|
||||
if not settings:
|
||||
return None
|
||||
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
org_kwargs = OrgStore.get_kwargs_from_settings(settings)
|
||||
for key, value in org_kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
user_kwargs = UserStore.get_kwargs_from_settings(settings)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
role_id=role_id,
|
||||
**user_kwargs,
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
role = RoleStore.get_role_by_name('owner')
|
||||
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
org_member_kwargs = OrgMemberStore.get_kwargs_from_settings(settings)
|
||||
# avoid setting org member llm fields to use org defaults on user creation
|
||||
del org_member_kwargs['llm_model']
|
||||
del org_member_kwargs['llm_base_url']
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id, # owner of your own org.
|
||||
status='active',
|
||||
**org_member_kwargs,
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def _get_redis_client():
|
||||
"""Get the Redis client from the Socket.IO manager."""
|
||||
from openhands.server.shared import sio
|
||||
|
||||
return getattr(sio.manager, 'redis', None)
|
||||
|
||||
@staticmethod
|
||||
async def _acquire_user_creation_lock(user_id: str) -> bool:
|
||||
"""Attempt to acquire a distributed lock for user creation.
|
||||
|
||||
Returns True if the lock was acquired or if Redis is unavailable (fallback to no locking).
|
||||
Returns False if another process holds the lock.
|
||||
"""
|
||||
redis_client = UserStore._get_redis_client()
|
||||
if redis_client is None:
|
||||
logger.warning(
|
||||
'saas_settings_store:_acquire_user_creation_lock:no_redis_client',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True # Proceed without locking if Redis is unavailable
|
||||
|
||||
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{user_id}'
|
||||
lock_acquired = await redis_client.set(
|
||||
user_key, 1, nx=True, ex=_REDIS_CREATE_TIMEOUT_SECONDS
|
||||
)
|
||||
return bool(lock_acquired)
|
||||
|
||||
@staticmethod
|
||||
async def migrate_user(
|
||||
user_id: str,
|
||||
user_settings: UserSettings,
|
||||
user_info: dict,
|
||||
) -> User:
|
||||
if not user_id or not user_settings:
|
||||
return None
|
||||
|
||||
kwargs = decrypt_legacy_model(
|
||||
[
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
],
|
||||
user_settings,
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
with session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
org_version=user_settings.user_version,
|
||||
contact_name=user_info['username'],
|
||||
contact_email=user_info['email'],
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
await LiteLlmManager.migrate_entries(
|
||||
str(org.id),
|
||||
user_id,
|
||||
decrypted_user_settings,
|
||||
)
|
||||
|
||||
custom_settings = UserStore._has_custom_settings(
|
||||
decrypted_user_settings, user_settings.user_version
|
||||
)
|
||||
|
||||
# avoids circular reference. This migrate method is temprorary until all users are migrated.
|
||||
from integrations.stripe_service import migrate_customer
|
||||
|
||||
await migrate_customer(session, user_id, org)
|
||||
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
org_kwargs = OrgStore.get_kwargs_from_user_settings(decrypted_user_settings)
|
||||
org_kwargs.pop('id', None)
|
||||
|
||||
# if user has custom settings, set org defaults to current version
|
||||
if custom_settings:
|
||||
org_kwargs['default_llm_model'] = get_default_litellm_model()
|
||||
org_kwargs['llm_base_url'] = LITE_LLM_API_URL
|
||||
org_kwargs['org_version'] = ORG_SETTINGS_VERSION
|
||||
|
||||
for key, value in org_kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
user_kwargs = UserStore.get_kwargs_from_user_settings(
|
||||
decrypted_user_settings
|
||||
)
|
||||
user_kwargs.pop('id', None)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
role_id=None,
|
||||
**user_kwargs,
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
role = RoleStore.get_role_by_name('owner')
|
||||
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
org_member_kwargs = OrgMemberStore.get_kwargs_from_user_settings(
|
||||
decrypted_user_settings
|
||||
)
|
||||
|
||||
# if the user did not have custom settings in the old model,
|
||||
# then use the org defaults by not setting org_member fields
|
||||
if not custom_settings:
|
||||
del org_member_kwargs['llm_model']
|
||||
del org_member_kwargs['llm_base_url']
|
||||
del org_member_kwargs['llm_api_key_for_byor']
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id, # owner of your own org.
|
||||
status='active',
|
||||
**org_member_kwargs,
|
||||
)
|
||||
session.add(org_member)
|
||||
|
||||
# Mark the old user_settings as migrated instead of deleting
|
||||
user_settings.already_migrated = True
|
||||
session.merge(user_settings)
|
||||
session.flush()
|
||||
|
||||
# need to migrate conversation metadata
|
||||
session.execute(
|
||||
text("""
|
||||
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
|
||||
SELECT
|
||||
conversation_id,
|
||||
:user_id,
|
||||
:user_id
|
||||
FROM conversation_metadata
|
||||
WHERE user_id = :user_id
|
||||
"""),
|
||||
{'user_id': user_id},
|
||||
)
|
||||
|
||||
# Update org_id for tables that had org_id added
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
|
||||
# Update stripe_customers
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE stripe_customers SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update slack_users
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE slack_users SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update slack_conversation
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE slack_conversation SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update api_keys
|
||||
session.execute(
|
||||
text('UPDATE api_keys SET org_id = :org_id WHERE user_id = :user_id'),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update custom_secrets
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE custom_secrets SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update billing_sessions
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE billing_sessions SET org_id = :org_id WHERE user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID (sync version).
|
||||
|
||||
Note: This method uses call_async_from_sync internally which creates a new
|
||||
event loop. If you're already in an async context, use get_user_by_id_async
|
||||
instead to avoid event loop conflicts.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
# Check if we need to migrate from user_settings
|
||||
while not call_async_from_sync(
|
||||
UserStore._acquire_user_creation_lock, GENERAL_TIMEOUT, user_id
|
||||
):
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
call_async_from_sync(
|
||||
asyncio.sleep, GENERAL_TIMEOUT, _RETRY_LOAD_DELAY_SECONDS
|
||||
)
|
||||
|
||||
# Check for user again as migration could have happened while trying to get the lock.
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user_settings:
|
||||
token_manager = TokenManager()
|
||||
user_info = call_async_from_sync(
|
||||
token_manager.get_user_info_from_user_id,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
)
|
||||
user = call_async_from_sync(
|
||||
UserStore.migrate_user,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
user_settings,
|
||||
user_info,
|
||||
)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def get_user_by_id_async(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID (async version).
|
||||
|
||||
This is the preferred method when calling from an async context as it
|
||||
avoids event loop conflicts that can occur with the sync version.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
# Check if we need to migrate from user_settings
|
||||
while not await UserStore._acquire_user_creation_lock(user_id):
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
|
||||
|
||||
# Check for user again as migration could have happened while trying to get the lock.
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user_settings:
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id)
|
||||
user = await UserStore.migrate_user(
|
||||
user_id,
|
||||
user_settings,
|
||||
user_info,
|
||||
)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def list_users() -> list[User]:
|
||||
"""List all users."""
|
||||
with session_maker() as session:
|
||||
return session.query(User).all()
|
||||
|
||||
# Prevent circular imports
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
@staticmethod
|
||||
async def create_default_settings(
|
||||
org_id: str, user_id: str, create_user: bool = True
|
||||
) -> Optional['Settings']:
|
||||
logger.info(
|
||||
'UserStore:create_default_settings:start',
|
||||
extra={'org_id': org_id, 'user_id': user_id},
|
||||
)
|
||||
# You must log in before you get default settings
|
||||
if not org_id:
|
||||
return None
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
settings = Settings(language='en', enable_proactive_conversation_starters=True)
|
||||
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
settings = await LiteLlmManager.create_entries(
|
||||
org_id, user_id, settings, create_user
|
||||
)
|
||||
if not settings:
|
||||
logger.info(
|
||||
'UserStore:create_default_settings:litellm_create_failed',
|
||||
extra={'org_id': org_id},
|
||||
)
|
||||
return None
|
||||
|
||||
return settings
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: 'Settings'):
|
||||
kwargs = {
|
||||
normalized: getattr(settings, normalized)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {
|
||||
normalized: getattr(user_settings, normalized)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def _has_custom_settings(
|
||||
user_settings: UserSettings, old_user_version: int | None
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has custom LLM settings that should be preserved.
|
||||
Returns True if user customized either model or base_url.
|
||||
|
||||
Args:
|
||||
settings: The user's current settings
|
||||
old_user_version: The user's old settings version, if any
|
||||
|
||||
Returns:
|
||||
True if user has custom settings, False if using old defaults
|
||||
"""
|
||||
# Normalize values
|
||||
user_model = (
|
||||
user_settings.llm_model.strip() or None if user_settings.llm_model else None
|
||||
)
|
||||
user_base_url = (
|
||||
user_settings.llm_base_url.strip() or None
|
||||
if user_settings.llm_base_url
|
||||
else None
|
||||
)
|
||||
|
||||
# Custom base_url = definitely custom settings (BYOK)
|
||||
if user_base_url and user_base_url != LITE_LLM_API_URL:
|
||||
return True
|
||||
|
||||
# No model set = using defaults
|
||||
if not user_model:
|
||||
return False
|
||||
|
||||
# Check if model matches old version's default
|
||||
if (
|
||||
old_user_version
|
||||
and old_user_version <= ORG_SETTINGS_VERSION
|
||||
and old_user_version in PERSONAL_WORKSPACE_VERSION_TO_MODEL
|
||||
):
|
||||
old_default_base = PERSONAL_WORKSPACE_VERSION_TO_MODEL[old_user_version]
|
||||
user_model_base = user_model.split('/')[-1]
|
||||
if user_model_base == old_default_base:
|
||||
return False # Matches old default
|
||||
|
||||
return True # Custom model
|
||||
@@ -21,7 +21,7 @@ from sqlalchemy import text
|
||||
# Add the parent directory to the path so we can import from storage
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from server.auth.token_manager import get_keycloak_admin
|
||||
from storage.database import engine
|
||||
from storage.database import get_engine
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -85,7 +85,7 @@ def get_recent_conversations(minutes: int = 60) -> List[Dict[str, Any]]:
|
||||
created_at DESC
|
||||
""")
|
||||
|
||||
with engine.connect() as connection:
|
||||
with get_engine().connect() as connection:
|
||||
result = connection.execute(query, {'minutes': minutes})
|
||||
conversations = [
|
||||
{
|
||||
|
||||
@@ -13,7 +13,7 @@ from sqlalchemy import text
|
||||
# Add the parent directory to the path so we can import from storage
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from storage.database import engine
|
||||
from storage.database import get_engine
|
||||
|
||||
|
||||
def test_conversation_count_query():
|
||||
@@ -29,6 +29,8 @@ def test_conversation_count_query():
|
||||
user_id
|
||||
""")
|
||||
|
||||
engine = get_engine()
|
||||
|
||||
with engine.connect() as connection:
|
||||
count_result = connection.execute(count_query)
|
||||
user_counts = [
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
from server.maintenance_task_processor.user_version_upgrade_processor import (
|
||||
UserVersionUpgradeProcessor,
|
||||
)
|
||||
from server.constants import ORG_SETTINGS_VERSION
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
@@ -15,11 +14,16 @@ from storage.conversation_work import ConversationWork
|
||||
from storage.device_code import DeviceCode # noqa: F401
|
||||
from storage.feedback import Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -68,7 +72,6 @@ def add_minimal_fixtures(session_maker):
|
||||
session.add(
|
||||
StoredConversationMetadata(
|
||||
conversation_id='mock-conversation-id',
|
||||
user_id='mock-user-id',
|
||||
created_at=datetime.fromisoformat('2025-03-07'),
|
||||
last_updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
accumulated_cost=5.25,
|
||||
@@ -77,6 +80,13 @@ def add_minimal_fixtures(session_maker):
|
||||
total_tokens=750,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StoredConversationMetadataSaas(
|
||||
conversation_id='mock-conversation-id',
|
||||
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StoredOfflineToken(
|
||||
user_id='mock-user-id',
|
||||
@@ -85,7 +95,38 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
|
||||
session.add(
|
||||
Org(
|
||||
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
name='mock-org',
|
||||
org_version=ORG_SETTINGS_VERSION,
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
Role(
|
||||
id=1,
|
||||
name='admin',
|
||||
rank=1,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
User(
|
||||
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
current_org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
user_consents_to_analytics=True,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
OrgMember(
|
||||
org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
user_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
role_id=1,
|
||||
llm_api_key='mock-api-key',
|
||||
status='active',
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id='mock-user-id',
|
||||
@@ -94,13 +135,6 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-10'),
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
UserSettings(
|
||||
keycloak_user_id='mock-user-id',
|
||||
user_consents_to_analytics=True,
|
||||
user_version=CURRENT_USER_SETTINGS_VERSION,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
ConversationWork(
|
||||
conversation_id='mock-conversation-id',
|
||||
@@ -109,17 +143,6 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
maintenance_task = MaintenanceTask(
|
||||
status=MaintenanceTaskStatus.PENDING,
|
||||
)
|
||||
maintenance_task.set_processor(
|
||||
UserVersionUpgradeProcessor(
|
||||
user_ids=['mock-user-id'],
|
||||
created_at=datetime.fromisoformat('2025-03-07'),
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
session.add(maintenance_task)
|
||||
session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from integrations.jira.jira_manager import JiraManager
|
||||
from integrations.jira.jira_view import (
|
||||
JiraExistingConversationView,
|
||||
JiraNewConversationView,
|
||||
)
|
||||
from integrations.models import JobContext
|
||||
@@ -194,21 +193,6 @@ def new_conversation_view(
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def existing_conversation_view(
|
||||
sample_job_context, sample_user_auth, sample_jira_user, sample_jira_workspace
|
||||
):
|
||||
"""JiraExistingConversationView instance for testing"""
|
||||
return JiraExistingConversationView(
|
||||
job_context=sample_job_context,
|
||||
saas_user_auth=sample_user_auth,
|
||||
jira_user=sample_jira_user,
|
||||
jira_workspace=sample_jira_workspace,
|
||||
selected_repo='test/repo1',
|
||||
conversation_id='conv-123',
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent_loop_info():
|
||||
"""Mock agent loop info"""
|
||||
|
||||
@@ -2,17 +2,12 @@
|
||||
Unit tests for JiraManager.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from integrations.jira.jira_manager import JiraManager
|
||||
from integrations.jira.jira_types import JiraViewInterface
|
||||
from integrations.jira.jira_view import (
|
||||
JiraExistingConversationView,
|
||||
JiraNewConversationView,
|
||||
)
|
||||
from integrations.models import Message, SourceType
|
||||
@@ -125,124 +120,83 @@ class TestGetRepositories:
|
||||
mock_client.get_repositories.assert_called_once()
|
||||
|
||||
|
||||
class TestValidateRequest:
|
||||
"""Test webhook request validation."""
|
||||
class TestGetWorkspaceNameFromPayload:
|
||||
"""Test workspace name extraction from webhook payload."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_request_success(
|
||||
def test_get_workspace_name_from_comment_created_payload(
|
||||
self,
|
||||
jira_manager,
|
||||
mock_token_manager,
|
||||
sample_jira_workspace,
|
||||
sample_comment_webhook_payload,
|
||||
):
|
||||
"""Test successful webhook validation."""
|
||||
# Setup mocks
|
||||
mock_token_manager.decrypt_text.return_value = 'test_secret'
|
||||
jira_manager.integration_store.get_workspace_by_name.return_value = (
|
||||
sample_jira_workspace
|
||||
"""Test extracting workspace name from comment_created webhook."""
|
||||
workspace_name = jira_manager.get_workspace_name_from_payload(
|
||||
sample_comment_webhook_payload
|
||||
)
|
||||
|
||||
# Create mock request
|
||||
body = json.dumps(sample_comment_webhook_payload).encode()
|
||||
signature = hmac.new('test_secret'.encode(), body, hashlib.sha256).hexdigest()
|
||||
assert workspace_name == 'test.atlassian.net'
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {'x-hub-signature': f'sha256={signature}'}
|
||||
mock_request.body = AsyncMock(return_value=body)
|
||||
mock_request.json = AsyncMock(return_value=sample_comment_webhook_payload)
|
||||
|
||||
is_valid, returned_signature, payload = await jira_manager.validate_request(
|
||||
mock_request
|
||||
)
|
||||
|
||||
assert is_valid is True
|
||||
assert returned_signature == signature
|
||||
assert payload == sample_comment_webhook_payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_request_missing_signature(
|
||||
self, jira_manager, sample_comment_webhook_payload
|
||||
):
|
||||
"""Test webhook validation with missing signature."""
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {}
|
||||
mock_request.body = AsyncMock(return_value=b'{}')
|
||||
mock_request.json = AsyncMock(return_value=sample_comment_webhook_payload)
|
||||
|
||||
is_valid, signature, payload = await jira_manager.validate_request(mock_request)
|
||||
|
||||
assert is_valid is False
|
||||
assert signature is None
|
||||
assert payload is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_request_workspace_not_found(
|
||||
self, jira_manager, sample_comment_webhook_payload
|
||||
):
|
||||
"""Test webhook validation when workspace is not found."""
|
||||
jira_manager.integration_store.get_workspace_by_name.return_value = None
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {'x-hub-signature': 'sha256=test_signature'}
|
||||
mock_request.body = AsyncMock(return_value=b'{}')
|
||||
mock_request.json = AsyncMock(return_value=sample_comment_webhook_payload)
|
||||
|
||||
is_valid, signature, payload = await jira_manager.validate_request(mock_request)
|
||||
|
||||
assert is_valid is False
|
||||
assert signature is None
|
||||
assert payload is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_request_workspace_inactive(
|
||||
def test_get_workspace_name_from_issue_updated_payload(
|
||||
self,
|
||||
jira_manager,
|
||||
mock_token_manager,
|
||||
sample_jira_workspace,
|
||||
sample_comment_webhook_payload,
|
||||
sample_issue_update_webhook_payload,
|
||||
):
|
||||
"""Test webhook validation when workspace is inactive."""
|
||||
sample_jira_workspace.status = 'inactive'
|
||||
jira_manager.integration_store.get_workspace_by_name.return_value = (
|
||||
sample_jira_workspace
|
||||
"""Test extracting workspace name from jira:issue_updated webhook."""
|
||||
workspace_name = jira_manager.get_workspace_name_from_payload(
|
||||
sample_issue_update_webhook_payload
|
||||
)
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {'x-hub-signature': 'sha256=test_signature'}
|
||||
mock_request.body = AsyncMock(return_value=b'{}')
|
||||
mock_request.json = AsyncMock(return_value=sample_comment_webhook_payload)
|
||||
assert workspace_name == 'jira.company.com'
|
||||
|
||||
is_valid, signature, payload = await jira_manager.validate_request(mock_request)
|
||||
|
||||
assert is_valid is False
|
||||
assert signature is None
|
||||
assert payload is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_request_invalid_signature(
|
||||
def test_get_workspace_name_from_unknown_event(
|
||||
self,
|
||||
jira_manager,
|
||||
mock_token_manager,
|
||||
sample_jira_workspace,
|
||||
sample_comment_webhook_payload,
|
||||
):
|
||||
"""Test webhook validation with invalid signature."""
|
||||
mock_token_manager.decrypt_text.return_value = 'test_secret'
|
||||
jira_manager.integration_store.get_workspace_by_name.return_value = (
|
||||
sample_jira_workspace
|
||||
)
|
||||
"""Test extracting workspace name from unknown webhook event."""
|
||||
payload = {
|
||||
'webhookEvent': 'unknown_event',
|
||||
'some_data': {'self': 'https://example.atlassian.net/rest/api/2/something'},
|
||||
}
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {'x-hub-signature': 'sha256=invalid_signature'}
|
||||
mock_request.body = AsyncMock(return_value=b'{}')
|
||||
mock_request.json = AsyncMock(return_value=sample_comment_webhook_payload)
|
||||
workspace_name = jira_manager.get_workspace_name_from_payload(payload)
|
||||
|
||||
is_valid, signature, payload = await jira_manager.validate_request(mock_request)
|
||||
assert workspace_name is None
|
||||
|
||||
assert is_valid is False
|
||||
assert signature is None
|
||||
assert payload is None
|
||||
def test_get_workspace_name_with_missing_author_self(
|
||||
self,
|
||||
jira_manager,
|
||||
):
|
||||
"""Test extracting workspace name when author self URL is missing."""
|
||||
payload = {
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {
|
||||
'body': 'Test comment',
|
||||
'author': {
|
||||
'emailAddress': 'user@test.com',
|
||||
'displayName': 'Test User',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
workspace_name = jira_manager.get_workspace_name_from_payload(payload)
|
||||
|
||||
assert workspace_name is None
|
||||
|
||||
def test_get_workspace_name_with_missing_user_self(
|
||||
self,
|
||||
jira_manager,
|
||||
):
|
||||
"""Test extracting workspace name when user self URL is missing."""
|
||||
payload = {
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'user': {
|
||||
'emailAddress': 'user@test.com',
|
||||
'displayName': 'Test User',
|
||||
},
|
||||
}
|
||||
|
||||
workspace_name = jira_manager.get_workspace_name_from_payload(payload)
|
||||
|
||||
assert workspace_name is None
|
||||
|
||||
|
||||
class TestParseWebhook:
|
||||
@@ -252,7 +206,11 @@ class TestParseWebhook:
|
||||
self, jira_manager, sample_comment_webhook_payload
|
||||
):
|
||||
"""Test parsing comment creation webhook."""
|
||||
job_context = jira_manager.parse_webhook(sample_comment_webhook_payload)
|
||||
message = Message(
|
||||
source=SourceType.JIRA,
|
||||
message={'payload': sample_comment_webhook_payload},
|
||||
)
|
||||
job_context = jira_manager.parse_webhook(message)
|
||||
|
||||
assert job_context is not None
|
||||
assert job_context.issue_id == '12345'
|
||||
@@ -264,7 +222,7 @@ class TestParseWebhook:
|
||||
assert job_context.base_api_url == 'https://test.atlassian.net'
|
||||
|
||||
def test_parse_webhook_comment_without_mention(self, jira_manager):
|
||||
"""Test parsing comment without @openhands mention."""
|
||||
"""Test parsing comment without @openhands mention raises error."""
|
||||
payload = {
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {
|
||||
@@ -272,6 +230,7 @@ class TestParseWebhook:
|
||||
'author': {
|
||||
'emailAddress': 'user@company.com',
|
||||
'displayName': 'Test User',
|
||||
'accountId': 'user456',
|
||||
'self': 'https://jira.company.com/rest/api/2/user?username=testuser',
|
||||
},
|
||||
},
|
||||
@@ -282,14 +241,19 @@ class TestParseWebhook:
|
||||
},
|
||||
}
|
||||
|
||||
job_context = jira_manager.parse_webhook(payload)
|
||||
assert job_context is None
|
||||
message = Message(source=SourceType.JIRA, message={'payload': payload})
|
||||
with pytest.raises(ValueError, match='Unrecognized jira event'):
|
||||
jira_manager.parse_webhook(message)
|
||||
|
||||
def test_parse_webhook_issue_update_with_openhands_label(
|
||||
self, jira_manager, sample_issue_update_webhook_payload
|
||||
):
|
||||
"""Test parsing issue update with openhands label."""
|
||||
job_context = jira_manager.parse_webhook(sample_issue_update_webhook_payload)
|
||||
message = Message(
|
||||
source=SourceType.JIRA,
|
||||
message={'payload': sample_issue_update_webhook_payload},
|
||||
)
|
||||
job_context = jira_manager.parse_webhook(message)
|
||||
|
||||
assert job_context is not None
|
||||
assert job_context.issue_id == '12345'
|
||||
@@ -299,7 +263,7 @@ class TestParseWebhook:
|
||||
assert job_context.display_name == 'Test User'
|
||||
|
||||
def test_parse_webhook_issue_update_without_openhands_label(self, jira_manager):
|
||||
"""Test parsing issue update without openhands label."""
|
||||
"""Test parsing issue update without openhands label raises error."""
|
||||
payload = {
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {'items': [{'field': 'labels', 'toString': 'bug,urgent'}]},
|
||||
@@ -311,12 +275,14 @@ class TestParseWebhook:
|
||||
'user': {
|
||||
'emailAddress': 'user@company.com',
|
||||
'displayName': 'Test User',
|
||||
'accountId': 'user456',
|
||||
'self': 'https://jira.company.com/rest/api/2/user?username=testuser',
|
||||
},
|
||||
}
|
||||
|
||||
job_context = jira_manager.parse_webhook(payload)
|
||||
assert job_context is None
|
||||
message = Message(source=SourceType.JIRA, message={'payload': payload})
|
||||
with pytest.raises(ValueError, match='Unrecognized jira event'):
|
||||
jira_manager.parse_webhook(message)
|
||||
|
||||
def test_parse_webhook_unsupported_event(self, jira_manager):
|
||||
"""Test parsing webhook with unsupported event."""
|
||||
@@ -325,8 +291,9 @@ class TestParseWebhook:
|
||||
'issue': {'id': '12345', 'key': 'PROJ-123'},
|
||||
}
|
||||
|
||||
job_context = jira_manager.parse_webhook(payload)
|
||||
assert job_context is None
|
||||
message = Message(source=SourceType.JIRA, message={'payload': payload})
|
||||
with pytest.raises(ValueError, match='Unrecognized jira event'):
|
||||
jira_manager.parse_webhook(message)
|
||||
|
||||
def test_parse_webhook_missing_required_fields(self, jira_manager):
|
||||
"""Test parsing webhook with missing required fields."""
|
||||
@@ -347,7 +314,8 @@ class TestParseWebhook:
|
||||
},
|
||||
}
|
||||
|
||||
job_context = jira_manager.parse_webhook(payload)
|
||||
message = Message(source=SourceType.JIRA, message={'payload': payload})
|
||||
job_context = jira_manager.parse_webhook(message)
|
||||
assert job_context is None
|
||||
|
||||
|
||||
@@ -374,12 +342,15 @@ class TestReceiveMessage:
|
||||
jira_manager.get_issue_details = AsyncMock(
|
||||
return_value=('Test Title', 'Test Description')
|
||||
)
|
||||
jira_manager.is_job_requested = AsyncMock(return_value=True)
|
||||
jira_manager.is_repository_specified = AsyncMock(return_value=True)
|
||||
jira_manager.start_job = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'integrations.jira.jira_manager.JiraFactory.create_jira_view_from_payload'
|
||||
) as mock_factory:
|
||||
) as mock_factory, patch(
|
||||
'integrations.jira.jira_manager.JiraFactory.is_ticket_comment',
|
||||
return_value=True,
|
||||
):
|
||||
mock_view = MagicMock(spec=JiraViewInterface)
|
||||
mock_factory.return_value = mock_view
|
||||
|
||||
@@ -537,23 +508,14 @@ class TestReceiveMessage:
|
||||
jira_manager._send_error_comment.assert_called_once()
|
||||
|
||||
|
||||
class TestIsJobRequested:
|
||||
"""Test job request validation."""
|
||||
class TestIsRepositorySpecified:
|
||||
"""Test repository specification validation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_job_requested_existing_conversation(self, jira_manager):
|
||||
"""Test job request validation for existing conversation."""
|
||||
mock_view = MagicMock(spec=JiraExistingConversationView)
|
||||
message = Message(source=SourceType.JIRA, message={})
|
||||
|
||||
result = await jira_manager.is_job_requested(message, mock_view)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_job_requested_new_conversation_with_repo_match(
|
||||
async def test_is_repository_specified_new_conversation_with_repo_match(
|
||||
self, jira_manager, sample_job_context, sample_user_auth
|
||||
):
|
||||
"""Test job request validation for new conversation with repository match."""
|
||||
"""Test repository validation for new conversation with repository match."""
|
||||
mock_view = MagicMock(spec=JiraNewConversationView)
|
||||
mock_view.saas_user_auth = sample_user_auth
|
||||
mock_view.job_context = sample_job_context
|
||||
@@ -575,16 +537,16 @@ class TestIsJobRequested:
|
||||
mock_filter.return_value = (True, mock_repos)
|
||||
|
||||
message = Message(source=SourceType.JIRA, message={})
|
||||
result = await jira_manager.is_job_requested(message, mock_view)
|
||||
result = await jira_manager.is_repository_specified(message, mock_view)
|
||||
|
||||
assert result is True
|
||||
assert mock_view.selected_repo == 'company/repo'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_job_requested_new_conversation_no_repo_match(
|
||||
async def test_is_repository_specified_new_conversation_no_repo_match(
|
||||
self, jira_manager, sample_job_context, sample_user_auth
|
||||
):
|
||||
"""Test job request validation for new conversation without repository match."""
|
||||
"""Test repository validation for new conversation without repository match."""
|
||||
mock_view = MagicMock(spec=JiraNewConversationView)
|
||||
mock_view.saas_user_auth = sample_user_auth
|
||||
mock_view.job_context = sample_job_context
|
||||
@@ -607,14 +569,16 @@ class TestIsJobRequested:
|
||||
mock_filter.return_value = (False, [])
|
||||
|
||||
message = Message(source=SourceType.JIRA, message={})
|
||||
result = await jira_manager.is_job_requested(message, mock_view)
|
||||
result = await jira_manager.is_repository_specified(message, mock_view)
|
||||
|
||||
assert result is False
|
||||
jira_manager._send_repo_selection_comment.assert_called_once_with(mock_view)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_job_requested_exception(self, jira_manager, sample_user_auth):
|
||||
"""Test job request validation when an exception occurs."""
|
||||
async def test_is_repository_specified_exception(
|
||||
self, jira_manager, sample_user_auth
|
||||
):
|
||||
"""Test repository validation when an exception occurs."""
|
||||
mock_view = MagicMock(spec=JiraNewConversationView)
|
||||
mock_view.saas_user_auth = sample_user_auth
|
||||
jira_manager._get_repositories = AsyncMock(
|
||||
@@ -622,7 +586,7 @@ class TestIsJobRequested:
|
||||
)
|
||||
|
||||
message = Message(source=SourceType.JIRA, message={})
|
||||
result = await jira_manager.is_job_requested(message, mock_view)
|
||||
result = await jira_manager.is_repository_specified(message, mock_view)
|
||||
|
||||
assert result is False
|
||||
|
||||
@@ -659,33 +623,6 @@ class TestStartJob:
|
||||
mock_register.assert_called_once()
|
||||
jira_manager.send_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_success_existing_conversation(
|
||||
self, jira_manager, sample_jira_workspace
|
||||
):
|
||||
"""Test successful job start for existing conversation."""
|
||||
mock_view = MagicMock(spec=JiraExistingConversationView)
|
||||
mock_view.jira_user = MagicMock()
|
||||
mock_view.jira_user.keycloak_user_id = 'test_user'
|
||||
mock_view.job_context = MagicMock()
|
||||
mock_view.job_context.issue_key = 'PROJ-123'
|
||||
mock_view.jira_workspace = sample_jira_workspace
|
||||
mock_view.create_or_update_conversation = AsyncMock(return_value='conv_123')
|
||||
mock_view.get_response_msg = MagicMock(return_value='Job started successfully')
|
||||
|
||||
jira_manager.send_message = AsyncMock()
|
||||
jira_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
|
||||
with patch(
|
||||
'integrations.jira.jira_manager.register_callback_processor'
|
||||
) as mock_register:
|
||||
await jira_manager.start_job(mock_view)
|
||||
|
||||
mock_view.create_or_update_conversation.assert_called_once()
|
||||
# Should not register callback for existing conversation
|
||||
mock_register.assert_not_called()
|
||||
jira_manager.send_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_missing_settings_error(
|
||||
self, jira_manager, sample_jira_workspace
|
||||
|
||||
@@ -2,17 +2,15 @@
|
||||
Tests for Jira view classes and factory.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.jira.jira_types import StartingConvoException
|
||||
from integrations.jira.jira_view import (
|
||||
JiraExistingConversationView,
|
||||
JiraFactory,
|
||||
JiraNewConversationView,
|
||||
)
|
||||
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from integrations.models import Message, SourceType
|
||||
|
||||
|
||||
class TestJiraNewConversationView:
|
||||
@@ -80,162 +78,9 @@ class TestJiraNewConversationView:
|
||||
assert 'conv-123' in response
|
||||
|
||||
|
||||
class TestJiraExistingConversationView:
|
||||
"""Tests for JiraExistingConversationView"""
|
||||
|
||||
def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
|
||||
"""Test _get_instructions method"""
|
||||
instructions, user_msg = existing_conversation_view._get_instructions(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
assert instructions == ''
|
||||
assert 'TEST-123' in user_msg
|
||||
assert 'Test Issue' in user_msg
|
||||
assert 'Fix this bug @openhands' in user_msg
|
||||
|
||||
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
|
||||
@patch('integrations.jira.jira_view.setup_init_conversation_settings')
|
||||
@patch('integrations.jira.jira_view.conversation_manager')
|
||||
@patch('integrations.jira.jira_view.get_final_agent_observation')
|
||||
async def test_create_or_update_conversation_success(
|
||||
self,
|
||||
mock_get_observation,
|
||||
mock_conversation_manager,
|
||||
mock_setup_init,
|
||||
mock_store_impl,
|
||||
existing_conversation_view,
|
||||
mock_jinja_env,
|
||||
mock_conversation_store,
|
||||
mock_conversation_init_data,
|
||||
mock_agent_loop_info,
|
||||
):
|
||||
"""Test successful existing conversation update"""
|
||||
# Setup mocks
|
||||
mock_store_impl.return_value = mock_conversation_store
|
||||
mock_setup_init.return_value = mock_conversation_init_data
|
||||
mock_conversation_manager.maybe_start_agent_loop = AsyncMock(
|
||||
return_value=mock_agent_loop_info
|
||||
)
|
||||
mock_conversation_manager.send_event_to_conversation = AsyncMock()
|
||||
|
||||
# Mock agent observation with RUNNING state
|
||||
mock_observation = MagicMock()
|
||||
mock_observation.agent_state = AgentState.RUNNING
|
||||
mock_get_observation.return_value = [mock_observation]
|
||||
|
||||
result = await existing_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
assert result == 'conv-123'
|
||||
mock_conversation_manager.send_event_to_conversation.assert_called_once()
|
||||
|
||||
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
|
||||
async def test_create_or_update_conversation_no_metadata(
|
||||
self, mock_store_impl, existing_conversation_view, mock_jinja_env
|
||||
):
|
||||
"""Test conversation update with no metadata"""
|
||||
mock_store = AsyncMock()
|
||||
mock_store.get_metadata.side_effect = FileNotFoundError(
|
||||
'No such file or directory'
|
||||
)
|
||||
mock_store_impl.return_value = mock_store
|
||||
|
||||
with pytest.raises(
|
||||
StartingConvoException, match='Conversation no longer exists'
|
||||
):
|
||||
await existing_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
|
||||
@patch('integrations.jira.jira_view.setup_init_conversation_settings')
|
||||
@patch('integrations.jira.jira_view.conversation_manager')
|
||||
@patch('integrations.jira.jira_view.get_final_agent_observation')
|
||||
async def test_create_or_update_conversation_loading_state(
|
||||
self,
|
||||
mock_get_observation,
|
||||
mock_conversation_manager,
|
||||
mock_setup_init,
|
||||
mock_store_impl,
|
||||
existing_conversation_view,
|
||||
mock_jinja_env,
|
||||
mock_conversation_store,
|
||||
mock_conversation_init_data,
|
||||
mock_agent_loop_info,
|
||||
):
|
||||
"""Test conversation update with loading state"""
|
||||
mock_store_impl.return_value = mock_conversation_store
|
||||
mock_setup_init.return_value = mock_conversation_init_data
|
||||
mock_conversation_manager.maybe_start_agent_loop = AsyncMock(
|
||||
return_value=mock_agent_loop_info
|
||||
)
|
||||
|
||||
# Mock agent observation with LOADING state
|
||||
mock_observation = MagicMock()
|
||||
mock_observation.agent_state = AgentState.LOADING
|
||||
mock_get_observation.return_value = [mock_observation]
|
||||
|
||||
with pytest.raises(
|
||||
StartingConvoException, match='Conversation is still starting'
|
||||
):
|
||||
await existing_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
|
||||
async def test_create_or_update_conversation_failure(
|
||||
self, mock_store_impl, existing_conversation_view, mock_jinja_env
|
||||
):
|
||||
"""Test conversation update failure"""
|
||||
mock_store_impl.side_effect = Exception('Store error')
|
||||
|
||||
with pytest.raises(
|
||||
StartingConvoException, match='Failed to create conversation'
|
||||
):
|
||||
await existing_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
def test_get_response_msg(self, existing_conversation_view):
|
||||
"""Test get_response_msg method"""
|
||||
response = existing_conversation_view.get_response_msg()
|
||||
|
||||
assert "I'm on it!" in response
|
||||
assert 'Test User' in response
|
||||
assert 'continue tracking my progress here' in response
|
||||
assert 'conv-123' in response
|
||||
|
||||
|
||||
class TestJiraFactory:
|
||||
"""Tests for JiraFactory"""
|
||||
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_create_jira_view_from_payload_existing_conversation(
|
||||
self,
|
||||
mock_store,
|
||||
sample_job_context,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
jira_conversation,
|
||||
):
|
||||
"""Test factory creating existing conversation view"""
|
||||
mock_store.get_user_conversations_by_issue_id = AsyncMock(
|
||||
return_value=jira_conversation
|
||||
)
|
||||
|
||||
view = await JiraFactory.create_jira_view_from_payload(
|
||||
sample_job_context,
|
||||
sample_user_auth,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
)
|
||||
|
||||
assert isinstance(view, JiraExistingConversationView)
|
||||
assert view.conversation_id == 'conv-123'
|
||||
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_create_jira_view_from_payload_new_conversation(
|
||||
self,
|
||||
@@ -341,83 +186,364 @@ class TestJiraViewEdgeCases:
|
||||
):
|
||||
await new_conversation_view.create_or_update_conversation(mock_jinja_env)
|
||||
|
||||
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
|
||||
@patch('integrations.jira.jira_view.setup_init_conversation_settings')
|
||||
@patch('integrations.jira.jira_view.conversation_manager')
|
||||
@patch('integrations.jira.jira_view.get_final_agent_observation')
|
||||
async def test_existing_conversation_empty_observations(
|
||||
self,
|
||||
mock_get_observation,
|
||||
mock_conversation_manager,
|
||||
mock_setup_init,
|
||||
mock_store_impl,
|
||||
existing_conversation_view,
|
||||
mock_jinja_env,
|
||||
mock_conversation_store,
|
||||
mock_conversation_init_data,
|
||||
mock_agent_loop_info,
|
||||
):
|
||||
"""Test existing conversation with empty observations"""
|
||||
mock_store_impl.return_value = mock_conversation_store
|
||||
mock_setup_init.return_value = mock_conversation_init_data
|
||||
mock_conversation_manager.maybe_start_agent_loop = AsyncMock(
|
||||
return_value=mock_agent_loop_info
|
||||
)
|
||||
mock_get_observation.return_value = [] # Empty observations
|
||||
|
||||
with pytest.raises(
|
||||
StartingConvoException, match='Conversation is still starting'
|
||||
):
|
||||
await existing_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
def test_new_conversation_view_attributes(self, new_conversation_view):
|
||||
"""Test new conversation view attribute access"""
|
||||
assert new_conversation_view.job_context.issue_key == 'TEST-123'
|
||||
assert new_conversation_view.selected_repo == 'test/repo1'
|
||||
assert new_conversation_view.conversation_id == 'conv-123'
|
||||
|
||||
def test_existing_conversation_view_attributes(self, existing_conversation_view):
|
||||
"""Test existing conversation view attribute access"""
|
||||
assert existing_conversation_view.job_context.issue_key == 'TEST-123'
|
||||
assert existing_conversation_view.selected_repo == 'test/repo1'
|
||||
assert existing_conversation_view.conversation_id == 'conv-123'
|
||||
|
||||
@patch('integrations.jira.jira_view.ConversationStoreImpl.get_instance')
|
||||
@patch('integrations.jira.jira_view.setup_init_conversation_settings')
|
||||
@patch('integrations.jira.jira_view.conversation_manager')
|
||||
@patch('integrations.jira.jira_view.get_final_agent_observation')
|
||||
async def test_existing_conversation_message_send_failure(
|
||||
self,
|
||||
mock_get_observation,
|
||||
mock_conversation_manager,
|
||||
mock_setup_init,
|
||||
mock_store_impl,
|
||||
existing_conversation_view,
|
||||
mock_jinja_env,
|
||||
mock_conversation_store,
|
||||
mock_conversation_init_data,
|
||||
mock_agent_loop_info,
|
||||
):
|
||||
"""Test existing conversation when message sending fails"""
|
||||
mock_store_impl.return_value = mock_conversation_store
|
||||
mock_setup_init.return_value = mock_conversation_init_data
|
||||
mock_conversation_manager.maybe_start_agent_loop = AsyncMock(
|
||||
return_value=mock_agent_loop_info
|
||||
)
|
||||
mock_conversation_manager.send_event_to_conversation = AsyncMock(
|
||||
side_effect=Exception('Send error')
|
||||
)
|
||||
class TestJiraFactoryIsLabeledTicket:
|
||||
"""Parameterized tests for JiraFactory.is_labeled_ticket method."""
|
||||
|
||||
# Mock agent observation with RUNNING state
|
||||
mock_observation = MagicMock()
|
||||
mock_observation.agent_state = AgentState.RUNNING
|
||||
mock_get_observation.return_value = [mock_observation]
|
||||
@pytest.mark.parametrize(
|
||||
'payload,expected',
|
||||
[
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [{'field': 'labels', 'toString': 'openhands'}]
|
||||
},
|
||||
},
|
||||
True,
|
||||
id='issue_updated_with_openhands_label',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [
|
||||
{'field': 'labels', 'toString': 'bug'},
|
||||
{'field': 'labels', 'toString': 'openhands'},
|
||||
]
|
||||
},
|
||||
},
|
||||
True,
|
||||
id='issue_updated_with_multiple_labels_including_openhands',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [{'field': 'labels', 'toString': 'bug,urgent'}]
|
||||
},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_without_openhands_label',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {'items': []},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_with_empty_changelog_items',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_with_empty_changelog',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
},
|
||||
False,
|
||||
id='issue_updated_without_changelog',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'changelog': {
|
||||
'items': [{'field': 'labels', 'toString': 'openhands'}]
|
||||
},
|
||||
},
|
||||
False,
|
||||
id='comment_created_event_with_label',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'issue_deleted',
|
||||
'changelog': {
|
||||
'items': [{'field': 'labels', 'toString': 'openhands'}]
|
||||
},
|
||||
},
|
||||
False,
|
||||
id='unsupported_event_type',
|
||||
),
|
||||
pytest.param(
|
||||
{},
|
||||
False,
|
||||
id='empty_payload',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [{'field': 'status', 'toString': 'In Progress'}]
|
||||
},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_with_non_label_field',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [{'field': 'labels', 'fromString': 'openhands'}]
|
||||
},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_with_fromString_instead_of_toString',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [
|
||||
{'field': 'labels', 'toString': 'not-openhands'},
|
||||
{'field': 'priority', 'toString': 'High'},
|
||||
]
|
||||
},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_with_mixed_fields_no_openhands',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_is_labeled_ticket(self, payload, expected):
|
||||
"""Test is_labeled_ticket with various payloads."""
|
||||
with patch('integrations.jira.jira_view.OH_LABEL', 'openhands'):
|
||||
message = Message(source=SourceType.JIRA, message={'payload': payload})
|
||||
result = JiraFactory.is_labeled_ticket(message)
|
||||
assert result == expected
|
||||
|
||||
with pytest.raises(
|
||||
StartingConvoException, match='Failed to create conversation'
|
||||
):
|
||||
await existing_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
@pytest.mark.parametrize(
|
||||
'payload,expected',
|
||||
[
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [{'field': 'labels', 'toString': 'openhands-exp'}]
|
||||
},
|
||||
},
|
||||
True,
|
||||
id='issue_updated_with_staging_label',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'changelog': {
|
||||
'items': [{'field': 'labels', 'toString': 'openhands'}]
|
||||
},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_with_prod_label_in_staging_env',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_is_labeled_ticket_staging_labels(self, payload, expected):
|
||||
"""Test is_labeled_ticket with staging environment labels."""
|
||||
with patch('integrations.jira.jira_view.OH_LABEL', 'openhands-exp'):
|
||||
message = Message(source=SourceType.JIRA, message={'payload': payload})
|
||||
result = JiraFactory.is_labeled_ticket(message)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestJiraFactoryIsTicketComment:
|
||||
"""Parameterized tests for JiraFactory.is_ticket_comment method."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'payload,expected',
|
||||
[
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Please fix this @openhands'},
|
||||
},
|
||||
True,
|
||||
id='comment_with_openhands_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': '@openhands please help'},
|
||||
},
|
||||
True,
|
||||
id='comment_starting_with_openhands_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Hello @openhands!'},
|
||||
},
|
||||
True,
|
||||
id='comment_with_openhands_mention_and_punctuation',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': '(@openhands)'},
|
||||
},
|
||||
True,
|
||||
id='comment_with_openhands_in_parentheses',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Hey @OpenHands can you help?'},
|
||||
},
|
||||
True,
|
||||
id='comment_with_case_insensitive_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Hey @OPENHANDS!'},
|
||||
},
|
||||
True,
|
||||
id='comment_with_uppercase_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Regular comment without mention'},
|
||||
},
|
||||
False,
|
||||
id='comment_without_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Hello @openhands-agent!'},
|
||||
},
|
||||
False,
|
||||
id='comment_with_openhands_as_prefix',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'user@openhands.com'},
|
||||
},
|
||||
False,
|
||||
id='comment_with_openhands_in_email',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': ''},
|
||||
},
|
||||
False,
|
||||
id='comment_with_empty_body',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {},
|
||||
},
|
||||
False,
|
||||
id='comment_without_body',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
},
|
||||
False,
|
||||
id='comment_created_without_comment_data',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'comment': {'body': 'Please fix this @openhands'},
|
||||
},
|
||||
False,
|
||||
id='issue_updated_event_with_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'issue_deleted',
|
||||
'comment': {'body': '@openhands'},
|
||||
},
|
||||
False,
|
||||
id='unsupported_event_type_with_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{},
|
||||
False,
|
||||
id='empty_payload',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Multiple @openhands @openhands mentions'},
|
||||
},
|
||||
True,
|
||||
id='comment_with_multiple_mentions',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_is_ticket_comment(self, payload, expected):
|
||||
"""Test is_ticket_comment with various payloads."""
|
||||
with patch('integrations.jira.jira_view.INLINE_OH_LABEL', '@openhands'), patch(
|
||||
'integrations.jira.jira_view.has_exact_mention'
|
||||
) as mock_has_exact_mention:
|
||||
from integrations.utils import has_exact_mention
|
||||
|
||||
mock_has_exact_mention.side_effect = (
|
||||
lambda text, mention: has_exact_mention(text, mention)
|
||||
)
|
||||
|
||||
message = Message(source=SourceType.JIRA, message={'payload': payload})
|
||||
result = JiraFactory.is_ticket_comment(message)
|
||||
assert result == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'payload,expected',
|
||||
[
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Please fix this @openhands-exp'},
|
||||
},
|
||||
True,
|
||||
id='comment_with_staging_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': '@openhands-exp please help'},
|
||||
},
|
||||
True,
|
||||
id='comment_starting_with_staging_mention',
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {'body': 'Please fix this @openhands'},
|
||||
},
|
||||
False,
|
||||
id='comment_with_prod_mention_in_staging_env',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_is_ticket_comment_staging_labels(self, payload, expected):
|
||||
"""Test is_ticket_comment with staging environment labels."""
|
||||
with patch(
|
||||
'integrations.jira.jira_view.INLINE_OH_LABEL', '@openhands-exp'
|
||||
), patch(
|
||||
'integrations.jira.jira_view.has_exact_mention'
|
||||
) as mock_has_exact_mention:
|
||||
from integrations.utils import has_exact_mention
|
||||
|
||||
mock_has_exact_mention.side_effect = (
|
||||
lambda text, mention: has_exact_mention(text, mention)
|
||||
)
|
||||
|
||||
message = Message(source=SourceType.JIRA, message={'payload': payload})
|
||||
result = JiraFactory.is_ticket_comment(message)
|
||||
assert result == expected
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Test for ResolverUserContext get_secrets conversion logic.
|
||||
"""Test for ResolverUserContext get_secrets and get_latest_token logic.
|
||||
|
||||
This test focuses on testing the actual ResolverUserContext implementation.
|
||||
"""
|
||||
@@ -12,7 +12,8 @@ from pydantic import SecretStr
|
||||
from enterprise.integrations.resolver_context import ResolverUserContext
|
||||
|
||||
# Import the real classes we want to test
|
||||
from openhands.integrations.provider import CustomSecret
|
||||
from openhands.integrations.provider import CustomSecret, ProviderToken
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
|
||||
# Import the SDK types we need for testing
|
||||
from openhands.sdk.secret import SecretSource, StaticSecret
|
||||
@@ -131,3 +132,135 @@ def test_custom_to_static_conversion():
|
||||
assert isinstance(static_secret, StaticSecret)
|
||||
assert isinstance(static_secret, SecretSource)
|
||||
assert static_secret.value.get_secret_value() == secret_value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for get_latest_token - ensuring string values are returned
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_provider_tokens(
|
||||
tokens_dict: dict[ProviderType, str],
|
||||
) -> dict[ProviderType, ProviderToken]:
|
||||
"""Helper to create provider tokens dictionary."""
|
||||
return {
|
||||
provider_type: ProviderToken(token=SecretStr(token_value))
|
||||
for provider_type, token_value in tokens_dict.items()
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_token_returns_string(resolver_context, mock_saas_user_auth):
|
||||
"""Test that get_latest_token returns a string, not a ProviderToken object."""
|
||||
# Arrange
|
||||
token_value = 'ghp_test_github_token_123'
|
||||
provider_tokens = create_provider_tokens({ProviderType.GITHUB: token_value})
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_latest_token(ProviderType.GITHUB)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert isinstance(result, str), (
|
||||
f'Expected str, got {type(result).__name__}. '
|
||||
'get_latest_token must return a string for StaticSecret compatibility.'
|
||||
)
|
||||
assert result == token_value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_token_returns_string_for_multiple_providers(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_latest_token returns strings for all provider types."""
|
||||
# Arrange
|
||||
provider_tokens = create_provider_tokens(
|
||||
{
|
||||
ProviderType.GITHUB: 'ghp_github_token',
|
||||
ProviderType.GITLAB: 'glpat_gitlab_token',
|
||||
ProviderType.BITBUCKET: 'bitbucket_token',
|
||||
}
|
||||
)
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
|
||||
# Act & Assert - verify each provider returns a string
|
||||
for provider_type, expected_token in [
|
||||
(ProviderType.GITHUB, 'ghp_github_token'),
|
||||
(ProviderType.GITLAB, 'glpat_gitlab_token'),
|
||||
(ProviderType.BITBUCKET, 'bitbucket_token'),
|
||||
]:
|
||||
result = await resolver_context.get_latest_token(provider_type)
|
||||
assert isinstance(
|
||||
result, str
|
||||
), f'Expected str for {provider_type.name}, got {type(result).__name__}'
|
||||
assert result == expected_token
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_token_returns_none_for_missing_provider(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_latest_token returns None when provider is not in tokens."""
|
||||
# Arrange - only GitHub token available
|
||||
provider_tokens = create_provider_tokens({ProviderType.GITHUB: 'ghp_token'})
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
|
||||
# Act - request GitLab token which doesn't exist
|
||||
result = await resolver_context.get_latest_token(ProviderType.GITLAB)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_token_returns_none_when_no_provider_tokens(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_latest_token returns None when no provider tokens exist."""
|
||||
# Arrange
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=None)
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_latest_token(ProviderType.GITHUB)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_token_returns_none_for_empty_token(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_latest_token returns None when provider token has no value."""
|
||||
# Arrange - provider exists but token is None
|
||||
provider_tokens = {ProviderType.GITHUB: ProviderToken(token=None)}
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_latest_token(ProviderType.GITHUB)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_token_can_be_used_with_static_secret(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_latest_token result can be used directly with StaticSecret.
|
||||
|
||||
This is a critical integration test to ensure the return value is compatible
|
||||
with how it's used in _setup_secrets_for_git_providers.
|
||||
"""
|
||||
# Arrange
|
||||
token_value = 'ghp_integration_test_token'
|
||||
provider_tokens = create_provider_tokens({ProviderType.GITHUB: token_value})
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
|
||||
# Act
|
||||
token = await resolver_context.get_latest_token(ProviderType.GITHUB)
|
||||
|
||||
# Assert - this should NOT raise a ValidationError
|
||||
static_secret = StaticSecret(value=token, description='GITHUB authentication token')
|
||||
assert static_secret.get_value() == token_value
|
||||
|
||||
@@ -6,17 +6,18 @@ import httpx
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from server.routes.api_keys import (
|
||||
delete_byor_key_from_litellm,
|
||||
get_llm_api_key_for_byor,
|
||||
verify_byor_key_in_litellm,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
|
||||
class TestVerifyByorKeyInLitellm:
|
||||
"""Test the verify_byor_key_in_litellm function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
async def test_verify_valid_key_returns_true(self, mock_client_class):
|
||||
"""Test that a valid key (200 response) returns True."""
|
||||
# Arrange
|
||||
@@ -32,7 +33,7 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
@@ -42,8 +43,8 @@ class TestVerifyByorKeyInLitellm:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
async def test_verify_invalid_key_401_returns_false(self, mock_client_class):
|
||||
"""Test that an invalid key (401 response) returns False."""
|
||||
# Arrange
|
||||
@@ -58,14 +59,14 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
async def test_verify_invalid_key_403_returns_false(self, mock_client_class):
|
||||
"""Test that an invalid key (403 response) returns False."""
|
||||
# Arrange
|
||||
@@ -80,14 +81,14 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
async def test_verify_server_error_returns_false(self, mock_client_class):
|
||||
"""Test that a server error (500) returns False to ensure key validity."""
|
||||
# Arrange
|
||||
@@ -103,14 +104,14 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
async def test_verify_timeout_returns_false(self, mock_client_class):
|
||||
"""Test that a timeout returns False to ensure key validity."""
|
||||
# Arrange
|
||||
@@ -123,14 +124,14 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('server.routes.api_keys.httpx.AsyncClient')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.httpx.AsyncClient')
|
||||
async def test_verify_network_error_returns_false(self, mock_client_class):
|
||||
"""Test that a network error returns False to ensure key validity."""
|
||||
# Arrange
|
||||
@@ -143,13 +144,13 @@ class TestVerifyByorKeyInLitellm:
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', None)
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', None)
|
||||
async def test_verify_missing_api_url_returns_false(self):
|
||||
"""Test that missing LITE_LLM_API_URL returns False."""
|
||||
# Arrange
|
||||
@@ -157,13 +158,13 @@ class TestVerifyByorKeyInLitellm:
|
||||
user_id = 'user-123'
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'https://litellm.example.com')
|
||||
async def test_verify_empty_key_returns_false(self):
|
||||
"""Test that empty key returns False."""
|
||||
# Arrange
|
||||
@@ -171,7 +172,7 @@ class TestVerifyByorKeyInLitellm:
|
||||
user_id = 'user-123'
|
||||
|
||||
# Act
|
||||
result = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
result = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
@@ -205,7 +206,7 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_valid_key_in_database_returns_key(
|
||||
self, mock_get_key, mock_verify_key
|
||||
@@ -229,7 +230,7 @@ class TestGetLlmApiKeyForByor:
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_invalid_key_in_database_regenerates(
|
||||
self,
|
||||
@@ -265,7 +266,7 @@ class TestGetLlmApiKeyForByor:
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_invalid_key_deletion_failure_still_regenerates(
|
||||
self,
|
||||
@@ -328,3 +329,99 @@ class TestGetLlmApiKeyForByor:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to retrieve BYOR LLM API key' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestDeleteByorKeyFromLitellm:
|
||||
"""Test the delete_byor_key_from_litellm function with alias cleanup."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id_async')
|
||||
async def test_delete_constructs_alias_from_user(
|
||||
self, mock_get_user, mock_delete_key
|
||||
):
|
||||
"""Test that delete_byor_key_from_litellm constructs key alias from user."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
org_id = 'org-456'
|
||||
byor_key = 'sk-byor-key-to-delete'
|
||||
expected_alias = f'BYOR Key - user {user_id}, org {org_id}'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = org_id
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_delete_key.return_value = None
|
||||
|
||||
# Act
|
||||
result = await delete_byor_key_from_litellm(user_id, byor_key)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
mock_delete_key.assert_called_once_with(byor_key, key_alias=expected_alias)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id_async')
|
||||
async def test_delete_without_user_passes_no_alias(
|
||||
self, mock_get_user, mock_delete_key
|
||||
):
|
||||
"""Test that when user is not found, no alias is passed."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
byor_key = 'sk-byor-key-to-delete'
|
||||
|
||||
mock_get_user.return_value = None
|
||||
mock_delete_key.return_value = None
|
||||
|
||||
# Act
|
||||
result = await delete_byor_key_from_litellm(user_id, byor_key)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_delete_key.assert_called_once_with(byor_key, key_alias=None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id_async')
|
||||
async def test_delete_without_org_id_passes_no_alias(
|
||||
self, mock_get_user, mock_delete_key
|
||||
):
|
||||
"""Test that when user has no current_org_id, no alias is passed."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
byor_key = 'sk-byor-key-to-delete'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = None
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_delete_key.return_value = None
|
||||
|
||||
# Act
|
||||
result = await delete_byor_key_from_litellm(user_id, byor_key)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_delete_key.assert_called_once_with(byor_key, key_alias=None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.delete_key')
|
||||
@patch('storage.user_store.UserStore.get_user_by_id_async')
|
||||
async def test_delete_returns_false_on_exception(
|
||||
self, mock_get_user, mock_delete_key
|
||||
):
|
||||
"""Test that exceptions during deletion return False."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
byor_key = 'sk-byor-key-to-delete'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = 'org-456'
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_delete_key.side_effect = Exception('LiteLLM API error')
|
||||
|
||||
# Act
|
||||
result = await delete_byor_key_from_litellm(user_id, byor_key)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
@@ -19,6 +21,7 @@ from server.routes.integration.jira import (
|
||||
jira_events,
|
||||
unlink_workspace,
|
||||
validate_workspace_integration,
|
||||
verify_jira_signature,
|
||||
)
|
||||
|
||||
|
||||
@@ -61,25 +64,35 @@ def mock_user_auth():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.jira_manager', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.verify_jira_signature', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.redis_client', new_callable=MagicMock)
|
||||
async def test_jira_events_invalid_signature(mock_redis, mock_manager, mock_request):
|
||||
async def test_jira_events_invalid_signature(mock_redis, mock_verify, mock_request):
|
||||
with patch('server.routes.integration.jira.JIRA_WEBHOOKS_ENABLED', True):
|
||||
mock_manager.validate_request.return_value = (False, None, None)
|
||||
mock_request.body = AsyncMock(return_value=b'{}')
|
||||
mock_request.json = AsyncMock(return_value={})
|
||||
mock_verify.side_effect = HTTPException(
|
||||
status_code=403, detail="Request signatures didn't match!"
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await jira_events(mock_request, MagicMock())
|
||||
await jira_events(
|
||||
mock_request, MagicMock(), x_hub_signature='sha256=invalid'
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == 'Invalid webhook signature!'
|
||||
assert exc_info.value.detail == "Request signatures didn't match!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.jira_manager', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.verify_jira_signature', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.redis_client')
|
||||
async def test_jira_events_duplicate_request(mock_redis, mock_manager, mock_request):
|
||||
async def test_jira_events_duplicate_request(mock_redis, mock_verify, mock_request):
|
||||
with patch('server.routes.integration.jira.JIRA_WEBHOOKS_ENABLED', True):
|
||||
mock_manager.validate_request.return_value = (True, 'sig123', 'payload')
|
||||
mock_request.body = AsyncMock(return_value=b'{}')
|
||||
mock_request.json = AsyncMock(return_value={})
|
||||
mock_verify.return_value = None
|
||||
mock_redis.exists.return_value = True
|
||||
response = await jira_events(mock_request, MagicMock())
|
||||
response = await jira_events(
|
||||
mock_request, MagicMock(), x_hub_signature='sha256=sig123'
|
||||
)
|
||||
assert response.status_code == 200
|
||||
body = json.loads(response.body)
|
||||
assert body['success'] is True
|
||||
@@ -348,18 +361,21 @@ class TestJiraLinkCreateValidation:
|
||||
# Test jira_events error scenarios
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.jira_manager', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.verify_jira_signature', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.redis_client', new_callable=MagicMock)
|
||||
async def test_jira_events_processing_success(mock_redis, mock_manager, mock_request):
|
||||
async def test_jira_events_processing_success(
|
||||
mock_redis, mock_verify, mock_manager, mock_request
|
||||
):
|
||||
with patch('server.routes.integration.jira.JIRA_WEBHOOKS_ENABLED', True):
|
||||
mock_manager.validate_request.return_value = (
|
||||
True,
|
||||
'sig123',
|
||||
{'test': 'payload'},
|
||||
)
|
||||
mock_request.body = AsyncMock(return_value=b'{"test": "payload"}')
|
||||
mock_request.json = AsyncMock(return_value={'test': 'payload'})
|
||||
mock_verify.return_value = None
|
||||
mock_redis.exists.return_value = False
|
||||
|
||||
background_tasks = MagicMock()
|
||||
response = await jira_events(mock_request, background_tasks)
|
||||
response = await jira_events(
|
||||
mock_request, background_tasks, x_hub_signature='sha256=sig123'
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
body = json.loads(response.body)
|
||||
@@ -369,19 +385,241 @@ async def test_jira_events_processing_success(mock_redis, mock_manager, mock_req
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.jira_manager', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.verify_jira_signature', new_callable=AsyncMock)
|
||||
@patch('server.routes.integration.jira.redis_client', new_callable=MagicMock)
|
||||
async def test_jira_events_general_exception(mock_redis, mock_manager, mock_request):
|
||||
async def test_jira_events_general_exception(mock_redis, mock_verify, mock_request):
|
||||
with patch('server.routes.integration.jira.JIRA_WEBHOOKS_ENABLED', True):
|
||||
mock_manager.validate_request.side_effect = Exception('Unexpected error')
|
||||
mock_request.body = AsyncMock(side_effect=Exception('Unexpected error'))
|
||||
mock_request.json = AsyncMock(return_value={})
|
||||
|
||||
response = await jira_events(mock_request, MagicMock())
|
||||
response = await jira_events(
|
||||
mock_request, MagicMock(), x_hub_signature='sha256=sig123'
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
body = json.loads(response.body)
|
||||
assert 'Internal server error processing webhook' in body['error']
|
||||
|
||||
|
||||
# Test verify_jira_signature
|
||||
class TestVerifyJiraSignature:
|
||||
"""Test Jira webhook signature verification."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_payload(self):
|
||||
"""Sample webhook payload with comment_created event."""
|
||||
return {
|
||||
'webhookEvent': 'comment_created',
|
||||
'comment': {
|
||||
'body': 'Test comment @openhands',
|
||||
'author': {
|
||||
'emailAddress': 'user@test.com',
|
||||
'displayName': 'Test User',
|
||||
'self': 'https://test.atlassian.net/rest/api/2/user?accountId=123',
|
||||
},
|
||||
},
|
||||
'issue': {
|
||||
'id': '12345',
|
||||
'key': 'TEST-123',
|
||||
'self': 'https://test.atlassian.net/rest/api/2/issue/12345',
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workspace(self):
|
||||
"""Create a mock workspace."""
|
||||
workspace = MagicMock()
|
||||
workspace.id = 1
|
||||
workspace.name = 'test.atlassian.net'
|
||||
workspace.status = 'active'
|
||||
workspace.webhook_secret = 'encrypted_secret'
|
||||
return workspace
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'signature,expected_detail',
|
||||
[
|
||||
(None, 'x-hub-signature header is missing!'),
|
||||
('', 'x-hub-signature header is missing!'),
|
||||
],
|
||||
ids=['signature_none', 'signature_empty'],
|
||||
)
|
||||
async def test_missing_signature(self, signature, expected_detail, sample_payload):
|
||||
"""Test that missing or empty signature raises HTTPException."""
|
||||
body = json.dumps(sample_payload).encode()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await verify_jira_signature(body, signature, sample_payload)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == expected_detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'payload',
|
||||
[
|
||||
{'webhookEvent': 'unknown_event'},
|
||||
{'webhookEvent': 'comment_created', 'comment': {}},
|
||||
{'webhookEvent': 'comment_created', 'comment': {'author': {}}},
|
||||
{'webhookEvent': 'jira:issue_updated', 'user': {}},
|
||||
{},
|
||||
],
|
||||
ids=[
|
||||
'unknown_event',
|
||||
'missing_author',
|
||||
'missing_self_url',
|
||||
'issue_updated_missing_self',
|
||||
'empty_payload',
|
||||
],
|
||||
)
|
||||
@patch('server.routes.integration.jira.jira_manager')
|
||||
async def test_workspace_name_not_found(self, mock_manager, payload):
|
||||
"""Test that missing workspace name in payload raises HTTPException."""
|
||||
mock_manager.get_workspace_name_from_payload.return_value = None
|
||||
body = json.dumps(payload).encode()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await verify_jira_signature(body, 'valid_signature', payload)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == 'Workspace name not found in payload'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.jira_manager')
|
||||
async def test_workspace_not_found_in_database(self, mock_manager, sample_payload):
|
||||
"""Test that workspace not found in database raises HTTPException."""
|
||||
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
|
||||
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
|
||||
return_value=None
|
||||
)
|
||||
body = json.dumps(sample_payload).encode()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await verify_jira_signature(body, 'valid_signature', sample_payload)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == 'Unidentified workspace'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'workspace_status',
|
||||
['inactive', 'disabled', 'pending'],
|
||||
ids=['inactive', 'disabled', 'pending'],
|
||||
)
|
||||
@patch('server.routes.integration.jira.jira_manager')
|
||||
async def test_workspace_not_active(
|
||||
self, mock_manager, workspace_status, sample_payload, mock_workspace
|
||||
):
|
||||
"""Test that inactive workspace raises HTTPException."""
|
||||
mock_workspace.status = workspace_status
|
||||
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
|
||||
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
body = json.dumps(sample_payload).encode()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await verify_jira_signature(body, 'valid_signature', sample_payload)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == 'Workspace is inactive'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.token_manager')
|
||||
@patch('server.routes.integration.jira.jira_manager')
|
||||
async def test_signature_mismatch(
|
||||
self, mock_manager, mock_token_mgr, sample_payload, mock_workspace
|
||||
):
|
||||
"""Test that signature mismatch raises HTTPException."""
|
||||
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
|
||||
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_token_mgr.decrypt_text.return_value = 'webhook_secret'
|
||||
body = json.dumps(sample_payload).encode()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await verify_jira_signature(body, 'invalid_signature', sample_payload)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == "Request signatures didn't match!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.token_manager')
|
||||
@patch('server.routes.integration.jira.jira_manager')
|
||||
async def test_valid_signature(
|
||||
self, mock_manager, mock_token_mgr, sample_payload, mock_workspace
|
||||
):
|
||||
"""Test that valid signature passes verification."""
|
||||
webhook_secret = 'webhook_secret'
|
||||
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
|
||||
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_token_mgr.decrypt_text.return_value = webhook_secret
|
||||
|
||||
body = json.dumps(sample_payload).encode()
|
||||
valid_signature = hmac.new(
|
||||
webhook_secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Should not raise any exception
|
||||
result = await verify_jira_signature(body, valid_signature, sample_payload)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'event_type,payload_key,author_key',
|
||||
[
|
||||
('comment_created', 'comment', 'author'),
|
||||
('jira:issue_updated', 'user', None),
|
||||
],
|
||||
ids=['comment_created', 'issue_updated'],
|
||||
)
|
||||
@patch('server.routes.integration.jira.token_manager')
|
||||
@patch('server.routes.integration.jira.jira_manager')
|
||||
async def test_valid_signature_different_events(
|
||||
self,
|
||||
mock_manager,
|
||||
mock_token_mgr,
|
||||
event_type,
|
||||
payload_key,
|
||||
author_key,
|
||||
mock_workspace,
|
||||
):
|
||||
"""Test valid signature verification for different webhook events."""
|
||||
webhook_secret = 'webhook_secret'
|
||||
mock_manager.get_workspace_name_from_payload.return_value = 'test.atlassian.net'
|
||||
mock_manager.integration_store.get_workspace_by_name = AsyncMock(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_token_mgr.decrypt_text.return_value = webhook_secret
|
||||
|
||||
if event_type == 'comment_created':
|
||||
payload = {
|
||||
'webhookEvent': event_type,
|
||||
'comment': {
|
||||
'body': 'Test',
|
||||
'author': {
|
||||
'self': 'https://test.atlassian.net/rest/api/2/user?id=1'
|
||||
},
|
||||
},
|
||||
}
|
||||
else:
|
||||
payload = {
|
||||
'webhookEvent': event_type,
|
||||
'user': {'self': 'https://test.atlassian.net/rest/api/2/user?id=1'},
|
||||
}
|
||||
|
||||
body = json.dumps(payload).encode()
|
||||
valid_signature = hmac.new(
|
||||
webhook_secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
result = await verify_jira_signature(body, valid_signature, payload)
|
||||
assert result is None
|
||||
|
||||
|
||||
# Test create_jira_workspace error scenarios
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.get_user_auth')
|
||||
|
||||
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
Integration tests for organization API routes.
|
||||
|
||||
Tests the POST /api/organizations endpoint with various scenarios.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Mock database before imports
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
LiteLLMIntegrationError,
|
||||
OrgDatabaseError,
|
||||
OrgNameExistsError,
|
||||
)
|
||||
from server.routes.orgs import org_router
|
||||
from storage.org import Org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
"""Create a test FastAPI app with organization routes and mocked auth."""
|
||||
app = FastAPI()
|
||||
app.include_router(org_router)
|
||||
|
||||
# Override the auth dependency to return a test user
|
||||
def mock_get_openhands_user_id():
|
||||
return 'test-user-123'
|
||||
|
||||
app.dependency_overrides[get_admin_user_id] = mock_get_openhands_user_id
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_success(mock_app):
|
||||
"""
|
||||
GIVEN: Valid organization creation request
|
||||
WHEN: POST /api/organizations is called
|
||||
THEN: Organization is created and returned with 201 status
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
mock_org = Org(
|
||||
id=org_id,
|
||||
name='Test Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
org_version=5,
|
||||
default_llm_model='claude-opus-4-5-20251101',
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
'name': 'Test Organization',
|
||||
'contact_name': 'John Doe',
|
||||
'contact_email': 'john@example.com',
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.create_org_with_owner',
|
||||
AsyncMock(return_value=mock_org),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.get_org_credits',
|
||||
AsyncMock(return_value=100.0),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
|
||||
# Act
|
||||
response = client.post('/api/organizations', json=request_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
response_data = response.json()
|
||||
assert response_data['name'] == 'Test Organization'
|
||||
assert response_data['contact_name'] == 'John Doe'
|
||||
assert response_data['contact_email'] == 'john@example.com'
|
||||
assert response_data['credits'] == 100.0
|
||||
assert response_data['org_version'] == 5
|
||||
assert response_data['default_llm_model'] == 'claude-opus-4-5-20251101'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_invalid_email(mock_app):
|
||||
"""
|
||||
GIVEN: Request with invalid email format
|
||||
WHEN: POST /api/organizations is called
|
||||
THEN: 422 validation error is returned
|
||||
"""
|
||||
# Arrange
|
||||
request_data = {
|
||||
'name': 'Test Organization',
|
||||
'contact_name': 'John Doe',
|
||||
'contact_email': 'invalid-email', # Missing @
|
||||
}
|
||||
|
||||
client = TestClient(mock_app)
|
||||
|
||||
# Act
|
||||
response = client.post('/api/organizations', json=request_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_empty_name(mock_app):
|
||||
"""
|
||||
GIVEN: Request with empty organization name
|
||||
WHEN: POST /api/organizations is called
|
||||
THEN: 422 validation error is returned
|
||||
"""
|
||||
# Arrange
|
||||
request_data = {
|
||||
'name': '', # Empty string (after whitespace stripping)
|
||||
'contact_name': 'John Doe',
|
||||
'contact_email': 'john@example.com',
|
||||
}
|
||||
|
||||
client = TestClient(mock_app)
|
||||
|
||||
# Act
|
||||
response = client.post('/api/organizations', json=request_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_duplicate_name(mock_app):
|
||||
"""
|
||||
GIVEN: Organization name already exists
|
||||
WHEN: POST /api/organizations is called
|
||||
THEN: 409 Conflict error is returned
|
||||
"""
|
||||
# Arrange
|
||||
request_data = {
|
||||
'name': 'Existing Organization',
|
||||
'contact_name': 'John Doe',
|
||||
'contact_email': 'john@example.com',
|
||||
}
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgService.create_org_with_owner',
|
||||
AsyncMock(side_effect=OrgNameExistsError('Existing Organization')),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
|
||||
# Act
|
||||
response = client.post('/api/organizations', json=request_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_409_CONFLICT
|
||||
assert 'already exists' in response.json()['detail'].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_litellm_failure(mock_app):
|
||||
"""
|
||||
GIVEN: LiteLLM integration fails
|
||||
WHEN: POST /api/organizations is called
|
||||
THEN: 500 Internal Server Error is returned
|
||||
"""
|
||||
# Arrange
|
||||
request_data = {
|
||||
'name': 'Test Organization',
|
||||
'contact_name': 'John Doe',
|
||||
'contact_email': 'john@example.com',
|
||||
}
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgService.create_org_with_owner',
|
||||
AsyncMock(side_effect=LiteLLMIntegrationError('LiteLLM API unavailable')),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
|
||||
# Act
|
||||
response = client.post('/api/organizations', json=request_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert 'LiteLLM integration' in response.json()['detail']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_database_failure(mock_app):
|
||||
"""
|
||||
GIVEN: Database operation fails
|
||||
WHEN: POST /api/organizations is called
|
||||
THEN: 500 Internal Server Error is returned
|
||||
"""
|
||||
# Arrange
|
||||
request_data = {
|
||||
'name': 'Test Organization',
|
||||
'contact_name': 'John Doe',
|
||||
'contact_email': 'john@example.com',
|
||||
}
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgService.create_org_with_owner',
|
||||
AsyncMock(side_effect=OrgDatabaseError('Database connection failed')),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
|
||||
# Act
|
||||
response = client.post('/api/organizations', json=request_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert 'Failed to create organization' in response.json()['detail']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_unexpected_error(mock_app):
|
||||
"""
|
||||
GIVEN: Unexpected error occurs
|
||||
WHEN: POST /api/organizations is called
|
||||
THEN: 500 Internal Server Error is returned with generic message
|
||||
"""
|
||||
# Arrange
|
||||
request_data = {
|
||||
'name': 'Test Organization',
|
||||
'contact_name': 'John Doe',
|
||||
'contact_email': 'john@example.com',
|
||||
}
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgService.create_org_with_owner',
|
||||
AsyncMock(side_effect=RuntimeError('Unexpected system error')),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
|
||||
# Act
|
||||
response = client.post('/api/organizations', json=request_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert 'unexpected error' in response.json()['detail'].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_unauthorized():
|
||||
"""
|
||||
GIVEN: User is not authenticated
|
||||
WHEN: POST /api/organizations is called
|
||||
THEN: 401 Unauthorized error is returned
|
||||
"""
|
||||
# Arrange
|
||||
app = FastAPI()
|
||||
app.include_router(org_router)
|
||||
|
||||
# Override to simulate unauthenticated user
|
||||
async def mock_unauthenticated():
|
||||
raise HTTPException(status_code=401, detail='User not authenticated')
|
||||
|
||||
app.dependency_overrides[get_admin_user_id] = mock_unauthenticated
|
||||
|
||||
request_data = {
|
||||
'name': 'Test Organization',
|
||||
'contact_name': 'John Doe',
|
||||
'contact_email': 'john@example.com',
|
||||
}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Act
|
||||
response = client.post('/api/organizations', json=request_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_forbidden_non_openhands_email():
|
||||
"""
|
||||
GIVEN: User email is not @openhands.dev
|
||||
WHEN: POST /api/organizations is called
|
||||
THEN: 403 Forbidden error is returned
|
||||
"""
|
||||
# Arrange
|
||||
app = FastAPI()
|
||||
app.include_router(org_router)
|
||||
|
||||
# Override to simulate non-@openhands.dev user
|
||||
async def mock_forbidden():
|
||||
raise HTTPException(
|
||||
status_code=403, detail='Access restricted to @openhands.dev users'
|
||||
)
|
||||
|
||||
app.dependency_overrides[get_admin_user_id] = mock_forbidden
|
||||
|
||||
request_data = {
|
||||
'name': 'Test Organization',
|
||||
'contact_name': 'John Doe',
|
||||
'contact_email': 'john@example.com',
|
||||
}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Act
|
||||
response = client.post('/api/organizations', json=request_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert 'openhands.dev' in response.json()['detail'].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_sensitive_fields_not_exposed(mock_app):
|
||||
"""
|
||||
GIVEN: Organization is created successfully
|
||||
WHEN: Response is returned
|
||||
THEN: Sensitive fields (API keys) are not exposed
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
mock_org = Org(
|
||||
id=org_id,
|
||||
name='Test Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
org_version=5,
|
||||
default_llm_model='claude-opus-4-5-20251101',
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
|
||||
request_data = {
|
||||
'name': 'Test Organization',
|
||||
'contact_name': 'John Doe',
|
||||
'contact_email': 'john@example.com',
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.create_org_with_owner',
|
||||
AsyncMock(return_value=mock_org),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.get_org_credits',
|
||||
AsyncMock(return_value=100.0),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
|
||||
# Act
|
||||
response = client.post('/api/organizations', json=request_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
response_data = response.json()
|
||||
|
||||
# Verify sensitive fields are not in response or are None
|
||||
assert (
|
||||
'default_llm_api_key_for_byor' not in response_data
|
||||
or response_data.get('default_llm_api_key_for_byor') is None
|
||||
)
|
||||
assert (
|
||||
'search_api_key' not in response_data
|
||||
or response_data.get('search_api_key') is None
|
||||
)
|
||||
assert (
|
||||
'sandbox_api_key' not in response_data
|
||||
or response_data.get('sandbox_api_key') is None
|
||||
)
|
||||
@@ -82,7 +82,7 @@ class TestGetUserId:
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
user_id = _get_user_id('mock-conversation-id')
|
||||
assert user_id == 'mock-user-id'
|
||||
assert user_id == '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
|
||||
def test_get_user_id_conversation_not_found(self, session_maker):
|
||||
"""Test getting user ID when conversation doesn't exist."""
|
||||
@@ -105,10 +105,12 @@ class TestGetSessionApiKey:
|
||||
return_value=[mock_agent_loop_info]
|
||||
)
|
||||
|
||||
api_key = await _get_session_api_key('user-123', 'conv-456')
|
||||
api_key = await _get_session_api_key(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', 'conv-456'
|
||||
)
|
||||
assert api_key == 'test-api-key'
|
||||
mock_manager.get_agent_loop_info.assert_called_once_with(
|
||||
'user-123', filter_to_sids={'conv-456'}
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', filter_to_sids={'conv-456'}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -118,7 +120,9 @@ class TestGetSessionApiKey:
|
||||
mock_manager.get_agent_loop_info = AsyncMock(return_value=[])
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
await _get_session_api_key('user-123', 'conv-456')
|
||||
await _get_session_api_key(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', 'conv-456'
|
||||
)
|
||||
|
||||
|
||||
class TestProcessEvent:
|
||||
@@ -142,10 +146,15 @@ class TestProcessEvent:
|
||||
mock_event = MagicMock()
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
|
||||
mock_file_store.write.assert_called_once_with(
|
||||
'users/user-123/conversations/conv-456/events/event-1.json',
|
||||
'users/5594c7b6-f959-4b81-92e9-b09c206f5081/conversations/conv-456/events/event-1.json',
|
||||
json.dumps(content),
|
||||
)
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
@@ -177,14 +186,19 @@ class TestProcessEvent:
|
||||
)
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
|
||||
mock_file_store.write.assert_called_once()
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
mock_invoke_callbacks.assert_called_once_with('conv-456', mock_event)
|
||||
mock_update_working_seconds.assert_called_once()
|
||||
mock_event_store_class.assert_called_once_with(
|
||||
'conv-456', mock_file_store, 'user-123'
|
||||
'conv-456', mock_file_store, '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -212,7 +226,12 @@ class TestProcessEvent:
|
||||
mock_event.agent_state = 'running' # Set RUNNING state to skip the update
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
|
||||
mock_file_store.write.assert_called_once()
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
|
||||
@@ -0,0 +1,368 @@
|
||||
"""Tests for SaasSQLAppConversationInfoService.
|
||||
|
||||
This module tests the SAAS implementation of SQLAppConversationInfoService,
|
||||
focusing on user isolation, SAAS metadata handling, and multi-tenant functionality.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from enterprise.server.utils.saas_app_conversation_info_injector import (
|
||||
SaasSQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(async_session) -> SaasSQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance for testing."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service_with_user(async_session) -> SaasSQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance with a user_id for testing."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(user_id='a1111111-1111-1111-1111-111111111111'),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation_info() -> AppConversationInfo:
|
||||
"""Create a sample AppConversationInfo for testing."""
|
||||
return AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
||||
sandbox_id='sandbox_123',
|
||||
selected_repository='https://github.com/test/repo',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Test Conversation',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[123, 456],
|
||||
llm_model='gpt-4',
|
||||
metrics=None,
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multiple_conversation_infos() -> list[AppConversationInfo]:
|
||||
"""Create multiple AppConversationInfo instances for testing."""
|
||||
base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
return [
|
||||
AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=None,
|
||||
sandbox_id=f'sandbox_{i}',
|
||||
selected_repository=f'https://github.com/test/repo{i}',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title=f'Test Conversation {i}',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[i * 100],
|
||||
llm_model='gpt-4',
|
||||
metrics=None,
|
||||
created_at=base_time.replace(hour=12 + i),
|
||||
updated_at=base_time.replace(hour=12 + i, minute=30),
|
||||
)
|
||||
for i in range(1, 6) # Create 5 conversations
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user1_context():
|
||||
"""Create user context for user1."""
|
||||
return SpecifyUserContext(user_id='a1111111-1111-1111-1111-111111111111')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user2_context():
|
||||
"""Create user context for user2."""
|
||||
return SpecifyUserContext(user_id='b2222222-2222-2222-2222-222222222222')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def saas_service_user1(mock_db_session, user1_context):
|
||||
"""Create a SaasSQLAppConversationInfoService instance for user1."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=mock_db_session, user_context=user1_context
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def saas_service_user2(mock_db_session, user2_context):
|
||||
"""Create a SaasSQLAppConversationInfoService instance for user2."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=mock_db_session, user_context=user2_context
|
||||
)
|
||||
|
||||
|
||||
class TestSaasSQLAppConversationInfoService:
|
||||
"""Test suite for SaasSQLAppConversationInfoService."""
|
||||
|
||||
def test_service_initialization(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
user1_context: SpecifyUserContext,
|
||||
):
|
||||
"""Test that the SAAS service is properly initialized."""
|
||||
assert saas_service_user1.user_context == user1_context
|
||||
assert saas_service_user1.db_session is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_context_isolation(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
saas_service_user2: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that different service instances have different user contexts."""
|
||||
user1_id = await saas_service_user1.user_context.get_user_id()
|
||||
user2_id = await saas_service_user2.user_context.get_user_id()
|
||||
|
||||
assert user1_id == 'a1111111-1111-1111-1111-111111111111'
|
||||
assert user2_id == 'b2222222-2222-2222-2222-222222222222'
|
||||
assert user1_id != user2_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secure_select_includes_user_filtering(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that _secure_select method includes user filtering."""
|
||||
# This test verifies that the _secure_select method exists and can be called
|
||||
# The actual SQL generation is tested implicitly through integration
|
||||
query = await saas_service_user1._secure_select()
|
||||
assert query is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_info_with_user_id_functionality(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that _to_info_with_user_id properly sets user_id from SAAS metadata."""
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
# Create mock metadata objects
|
||||
stored_metadata = MagicMock(spec=StoredConversationMetadata)
|
||||
stored_metadata.conversation_id = '12345678-1234-5678-1234-567812345678'
|
||||
stored_metadata.parent_conversation_id = None
|
||||
stored_metadata.title = 'Test Conversation'
|
||||
stored_metadata.sandbox_id = 'test-sandbox'
|
||||
stored_metadata.selected_repository = None
|
||||
stored_metadata.selected_branch = None
|
||||
stored_metadata.git_provider = None
|
||||
stored_metadata.trigger = None
|
||||
stored_metadata.pr_number = []
|
||||
stored_metadata.llm_model = None
|
||||
from datetime import datetime, timezone
|
||||
|
||||
stored_metadata.created_at = datetime.now(timezone.utc)
|
||||
stored_metadata.last_updated_at = datetime.now(timezone.utc)
|
||||
stored_metadata.accumulated_cost = 0.0
|
||||
stored_metadata.prompt_tokens = 0
|
||||
stored_metadata.completion_tokens = 0
|
||||
stored_metadata.total_tokens = 0
|
||||
stored_metadata.max_budget_per_task = None
|
||||
stored_metadata.cache_read_tokens = 0
|
||||
stored_metadata.cache_write_tokens = 0
|
||||
stored_metadata.reasoning_tokens = 0
|
||||
stored_metadata.context_window = 0
|
||||
stored_metadata.per_turn_token = 0
|
||||
|
||||
saas_metadata = MagicMock(spec=StoredConversationMetadataSaas)
|
||||
saas_metadata.user_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
saas_metadata.org_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
|
||||
# Test the _to_info_with_user_id method
|
||||
result = saas_service_user1._to_info_with_user_id(
|
||||
stored_metadata, saas_metadata
|
||||
)
|
||||
|
||||
# Verify that the user_id from SAAS metadata is used
|
||||
assert result.created_by_user_id == 'a1111111-1111-1111-1111-111111111111'
|
||||
assert result.title == 'Test Conversation'
|
||||
assert result.sandbox_id == 'test-sandbox'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_isolation(
|
||||
self,
|
||||
async_session: AsyncSession,
|
||||
multiple_conversation_infos: list[AppConversationInfo],
|
||||
):
|
||||
"""Test that user isolation works correctly."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from storage.user import User
|
||||
|
||||
# Mock the database session execute method to return mock users
|
||||
# This mock intercepts User queries and returns a mock user object
|
||||
# with user_id and org_id the same as the user_id_uuid from the query
|
||||
original_execute = async_session.execute
|
||||
|
||||
async def mock_execute(query):
|
||||
query_str = str(query)
|
||||
|
||||
# Check if this is a User query
|
||||
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
|
||||
# Extract the UUID from the query parameters
|
||||
# The query will have bound parameters, we need to get the UUID value
|
||||
if hasattr(query, 'compile'):
|
||||
try:
|
||||
compiled = query.compile(compile_kwargs={'literal_binds': True})
|
||||
query_with_params = str(compiled)
|
||||
|
||||
# Extract UUID from the query string
|
||||
import re
|
||||
|
||||
# Try both formats: with dashes and without dashes
|
||||
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
|
||||
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
|
||||
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_with_dashes, query_with_params
|
||||
)
|
||||
if not uuid_match:
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_without_dashes, query_with_params
|
||||
)
|
||||
|
||||
if uuid_match:
|
||||
user_id_str = uuid_match.group(0)
|
||||
# If the UUID doesn't have dashes, add them
|
||||
if len(user_id_str) == 32 and '-' not in user_id_str:
|
||||
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
|
||||
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
|
||||
# Create a mock user with user_id and org_id the same as user_id_uuid
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.id = user_id_uuid
|
||||
mock_user.current_org_id = user_id_uuid
|
||||
|
||||
# Create a mock result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_user
|
||||
return mock_result
|
||||
except Exception:
|
||||
# If there's any error in parsing, fall back to original execute
|
||||
pass
|
||||
|
||||
# For all other queries, use the original execute method
|
||||
return await original_execute(query)
|
||||
|
||||
# Apply the mock
|
||||
async_session.execute = mock_execute
|
||||
|
||||
# Create services for different users
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='a1111111-1111-1111-1111-111111111111'
|
||||
),
|
||||
)
|
||||
user2_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='b2222222-2222-2222-2222-222222222222'
|
||||
),
|
||||
)
|
||||
|
||||
# Create conversations for different users
|
||||
user1_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
||||
sandbox_id='sandbox_user1',
|
||||
title='User 1 Conversation',
|
||||
)
|
||||
|
||||
user2_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='b2222222-2222-2222-2222-222222222222',
|
||||
sandbox_id='sandbox_user2',
|
||||
title='User 2 Conversation',
|
||||
)
|
||||
|
||||
# Save conversations
|
||||
await user1_service.save_app_conversation_info(user1_info)
|
||||
await user2_service.save_app_conversation_info(user2_info)
|
||||
|
||||
# User 1 should only see their conversation
|
||||
user1_page = await user1_service.search_app_conversation_info()
|
||||
assert len(user1_page.items) == 1
|
||||
assert (
|
||||
user1_page.items[0].created_by_user_id
|
||||
== 'a1111111-1111-1111-1111-111111111111'
|
||||
)
|
||||
|
||||
# User 2 should only see their conversation
|
||||
user2_page = await user2_service.search_app_conversation_info()
|
||||
assert len(user2_page.items) == 1
|
||||
assert (
|
||||
user2_page.items[0].created_by_user_id
|
||||
== 'b2222222-2222-2222-2222-222222222222'
|
||||
)
|
||||
|
||||
# User 1 should not be able to get user 2's conversation
|
||||
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
|
||||
assert user2_from_user1 is None
|
||||
|
||||
# User 2 should not be able to get user 1's conversation
|
||||
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
|
||||
assert user1_from_user2 is None
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
@@ -19,6 +19,14 @@ def mock_session_maker(mock_session):
|
||||
return session_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Mock user with org_id."""
|
||||
user = MagicMock()
|
||||
user.current_org_id = 'test-org-123'
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_store(mock_session_maker):
|
||||
return ApiKeyStore(mock_session_maker)
|
||||
@@ -33,11 +41,13 @@ def test_generate_api_key(api_key_store):
|
||||
assert len(key) == len('sk-oh-') + 32
|
||||
|
||||
|
||||
def test_create_api_key(api_key_store, mock_session):
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
"""Test creating an API key."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
name = 'Test Key'
|
||||
mock_get_user.return_value = mock_user
|
||||
api_key_store.generate_api_key = MagicMock(return_value='test-api-key')
|
||||
|
||||
# Execute
|
||||
@@ -45,10 +55,15 @@ def test_create_api_key(api_key_store, mock_session):
|
||||
|
||||
# Verify
|
||||
assert result == 'test-api-key'
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
api_key_store.generate_api_key.assert_called_once()
|
||||
|
||||
# Verify the ApiKey was created with the correct org_id
|
||||
added_api_key = mock_session.add.call_args[0][0]
|
||||
assert added_api_key.org_id == mock_user.current_org_id
|
||||
|
||||
|
||||
def test_validate_api_key_valid(api_key_store, mock_session):
|
||||
"""Test validating a valid API key."""
|
||||
@@ -204,10 +219,12 @@ def test_delete_api_key_by_id(api_key_store, mock_session):
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_list_api_keys(api_key_store, mock_session):
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
"""Test listing API keys for a user."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
now = datetime.now(UTC)
|
||||
mock_key1 = MagicMock()
|
||||
mock_key1.id = 1
|
||||
@@ -223,15 +240,17 @@ def test_list_api_keys(api_key_store, mock_session):
|
||||
mock_key2.last_used_at = None
|
||||
mock_key2.expires_at = None
|
||||
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = [
|
||||
mock_key1,
|
||||
mock_key2,
|
||||
]
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_key1, mock_key2]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.list_api_keys(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert len(result) == 2
|
||||
assert result[0]['id'] == 1
|
||||
assert result[0]['name'] == 'Key 1'
|
||||
@@ -244,3 +263,59 @@ def test_list_api_keys(api_key_store, mock_session):
|
||||
assert result[1]['created_at'] == now
|
||||
assert result[1]['last_used_at'] is None
|
||||
assert result[1]['expires_at'] is None
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_retrieve_mcp_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
"""Test retrieving MCP API key for a user."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
mock_mcp_key = MagicMock()
|
||||
mock_mcp_key.name = 'MCP_API_KEY'
|
||||
mock_mcp_key.key = 'mcp-test-key'
|
||||
|
||||
mock_other_key = MagicMock()
|
||||
mock_other_key.name = 'Other Key'
|
||||
mock_other_key.key = 'other-test-key'
|
||||
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_other_key, mock_mcp_key]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert result == 'mcp-test-key'
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_retrieve_mcp_api_key_not_found(
|
||||
mock_get_user, api_key_store, mock_session, mock_user
|
||||
):
|
||||
"""Test retrieving MCP API key when none exists."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
mock_other_key = MagicMock()
|
||||
mock_other_key.name = 'Other Key'
|
||||
mock_other_key.key = 'other-test-key'
|
||||
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_other_key]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert result is None
|
||||
|
||||
@@ -130,6 +130,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -144,6 +145,15 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = None
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = False
|
||||
|
||||
@@ -165,20 +175,19 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.set_response_cookie') as mock_set_cookie,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.posthog') as mock_posthog,
|
||||
):
|
||||
# Mock the session and query results
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
# Mock user settings with accepted_tos
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
# Setup UserStore mocks
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -228,6 +237,7 @@ async def test_keycloak_callback_email_not_verified(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.email.verify_email', mock_verify_email),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -243,6 +253,13 @@ async def test_keycloak_callback_email_not_verified(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_verifier.is_active.return_value = False
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
@@ -267,6 +284,7 @@ async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.email.verify_email', mock_verify_email),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -282,6 +300,13 @@ async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_verifier.is_active.return_value = False
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
@@ -309,20 +334,20 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
),
|
||||
patch('server.routes.auth.KEYCLOAK_REALM_NAME', 'test-realm'),
|
||||
patch('server.routes.auth.KEYCLOAK_CLIENT_ID', 'test-client'),
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.posthog') as mock_posthog,
|
||||
):
|
||||
# Mock the session and query results
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Mock user settings with accepted_tos
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
@@ -535,6 +560,7 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -549,6 +575,14 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
)
|
||||
mock_token_manager.disable_keycloak_user = AsyncMock()
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
@@ -576,6 +610,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -602,6 +637,15 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -629,6 +673,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -655,6 +700,15 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
@@ -680,6 +734,7 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -706,6 +761,16 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -725,6 +790,7 @@ async def test_keycloak_callback_duplicate_email_detected(mock_request):
|
||||
"""Test keycloak_callback when duplicate email is detected."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
@@ -741,6 +807,13 @@ async def test_keycloak_callback_duplicate_email_detected(mock_request):
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
@@ -761,6 +834,7 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
|
||||
"""Test keycloak_callback when duplicate is detected but deletion fails."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
@@ -777,6 +851,13 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
|
||||
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
|
||||
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=False)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
code='test_code', state='test_state', request=mock_request
|
||||
@@ -796,6 +877,7 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
@@ -825,6 +907,14 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -846,6 +936,7 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
@@ -873,6 +964,14 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -898,6 +997,7 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
@@ -924,6 +1024,14 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1043,6 +1151,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1071,6 +1180,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1112,6 +1229,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -1127,6 +1245,13 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Patch the module-level recaptcha_service instance
|
||||
@@ -1172,6 +1297,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1200,6 +1326,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1250,6 +1384,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1278,6 +1413,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1325,6 +1468,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1353,6 +1497,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1399,6 +1551,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1427,6 +1580,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1470,6 +1631,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1498,6 +1660,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1527,6 +1697,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1555,6 +1726,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1590,6 +1769,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.auth.logger') as mock_logger,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
@@ -1618,6 +1798,14 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
@@ -1665,6 +1853,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.logger') as mock_logger,
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -1680,6 +1869,13 @@ class TestKeycloakCallbackRecaptcha:
|
||||
return_value=False
|
||||
)
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Patch the module-level recaptcha_service instance
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@@ -5,22 +6,21 @@ import pytest
|
||||
import stripe
|
||||
from fastapi import HTTPException, Request, status
|
||||
from httpx import Response
|
||||
from integrations.stripe_service import has_payment_method
|
||||
from server.routes import billing
|
||||
from server.routes.billing import (
|
||||
CreateBillingSessionResponse,
|
||||
CreateCheckoutSessionRequest,
|
||||
GetCreditsResponse,
|
||||
cancel_callback,
|
||||
cancel_subscription,
|
||||
create_checkout_session,
|
||||
create_subscription_checkout_session,
|
||||
create_customer_setup_session,
|
||||
get_credits,
|
||||
has_payment_method,
|
||||
success_callback,
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session_type import BillingSessionType
|
||||
from storage.stripe_customer import Base as StripeCustomerBase
|
||||
|
||||
|
||||
@@ -78,28 +78,32 @@ def mock_subscription_request():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_lite_llm_error():
|
||||
mock_response = Response(
|
||||
status_code=500, json={'error': 'Internal Server Error'}, request=MagicMock()
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
with patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'):
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_credits('mock_user')
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert (
|
||||
exc_info.value.detail
|
||||
== 'Failed to retrieve credit balance from billing service'
|
||||
)
|
||||
with (
|
||||
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await get_credits('mock_user')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_success():
|
||||
mock_response = Response(
|
||||
status_code=200,
|
||||
json={'user_info': {'max_budget': 100.00, 'spend': 25.50}},
|
||||
json={
|
||||
'user_info': {
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
}
|
||||
},
|
||||
request=MagicMock(),
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
@@ -108,24 +112,23 @@ async def test_get_credits_success():
|
||||
with (
|
||||
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
|
||||
patch('httpx.AsyncClient', return_value=mock_client),
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
},
|
||||
),
|
||||
):
|
||||
with patch('server.routes.billing.session_maker') as mock_session_maker:
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.first.return_value = MagicMock(
|
||||
billing_margin=4
|
||||
)
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
result = await get_credits('mock_user')
|
||||
|
||||
result = await get_credits('mock_user')
|
||||
|
||||
assert isinstance(result, GetCreditsResponse)
|
||||
assert result.credits == Decimal(
|
||||
'74.50'
|
||||
) # 100.00 - 25.50 = 74.50 (no billing margin applied)
|
||||
mock_client.__aenter__.return_value.get.assert_called_once_with(
|
||||
'https://llm-proxy.app.all-hands.dev/user/info?user_id=mock_user',
|
||||
headers={'x-goog-api-key': None},
|
||||
)
|
||||
assert isinstance(result, GetCreditsResponse)
|
||||
assert result.credits == Decimal('74.50') # 100.00 - 25.50 = 74.50
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -138,6 +141,9 @@ async def test_create_checkout_session_stripe_error(
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
)
|
||||
mock_customer_create = AsyncMock(return_value=mock_customer)
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = uuid.uuid4()
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
with (
|
||||
pytest.raises(Exception, match='Stripe API Error'),
|
||||
patch('stripe.Customer.create_async', mock_customer_create),
|
||||
@@ -149,6 +155,10 @@ async def test_create_checkout_session_stripe_error(
|
||||
AsyncMock(side_effect=Exception('Stripe API Error')),
|
||||
),
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
@@ -174,6 +184,10 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
)
|
||||
mock_customer_create = AsyncMock(return_value=mock_customer)
|
||||
mock_org = MagicMock()
|
||||
mock_org_id = uuid.uuid4()
|
||||
mock_org.id = mock_org_id
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
with (
|
||||
patch('stripe.Customer.create_async', mock_customer_create),
|
||||
patch(
|
||||
@@ -182,6 +196,10 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
@@ -253,7 +271,6 @@ async def test_success_callback_stripe_incomplete():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
@@ -281,44 +298,34 @@ async def test_success_callback_success():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
mock_lite_llm_response = Response(
|
||||
status_code=200,
|
||||
json={'user_info': {'max_budget': 100.00, 'spend': 25.50}},
|
||||
request=MagicMock(),
|
||||
)
|
||||
mock_lite_llm_update_response = Response(
|
||||
status_code=200, json={}, request=MagicMock()
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch('httpx.AsyncClient') as mock_client,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
},
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
|
||||
) as mock_update_budget,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_user_settings = MagicMock(billing_margin=None)
|
||||
mock_db_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_user_settings
|
||||
)
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete',
|
||||
amount_subtotal=2500,
|
||||
status='complete', amount_subtotal=2500, customer='mock_customer_id'
|
||||
) # $25.00 in cents
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__.return_value.get.return_value = (
|
||||
mock_lite_llm_response
|
||||
)
|
||||
mock_client_instance.__aenter__.return_value.post.return_value = (
|
||||
mock_lite_llm_update_response
|
||||
)
|
||||
mock_client.return_value = mock_client_instance
|
||||
|
||||
response = await success_callback('test_session_id', mock_request)
|
||||
|
||||
assert response.status_code == 302
|
||||
@@ -328,18 +335,14 @@ async def test_success_callback_success():
|
||||
)
|
||||
|
||||
# Verify LiteLLM API calls
|
||||
mock_client_instance.__aenter__.return_value.get.assert_called_once()
|
||||
mock_client_instance.__aenter__.return_value.post.assert_called_once_with(
|
||||
'https://llm-proxy.app.all-hands.dev/user/update',
|
||||
headers={'x-goog-api-key': None},
|
||||
json={
|
||||
'user_id': 'mock_user',
|
||||
'max_budget': 125,
|
||||
}, # 100 + (25.00 from Stripe)
|
||||
mock_update_budget.assert_called_once_with(
|
||||
'mock_org_id',
|
||||
125.0, # 100 + (25.00 from Stripe)
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
assert mock_billing_session.status == 'completed'
|
||||
assert mock_billing_session.price == 25.0
|
||||
mock_db_session.merge.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@@ -353,27 +356,28 @@ async def test_success_callback_lite_llm_error():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch('httpx.AsyncClient') as mock_client,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete', amount_total=2500
|
||||
status='complete', amount_subtotal=2500
|
||||
)
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__.return_value.get.side_effect = Exception(
|
||||
'LiteLLM API Error'
|
||||
)
|
||||
mock_client.return_value = mock_client_instance
|
||||
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await success_callback('test_session_id', mock_request)
|
||||
|
||||
@@ -397,7 +401,8 @@ async def test_cancel_callback_session_not_found():
|
||||
response = await cancel_callback('test_session_id', mock_request)
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location'] == 'http://test.com/settings?checkout=cancel'
|
||||
response.headers['location']
|
||||
== 'http://test.com/settings/billing?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify no database updates occurred
|
||||
@@ -423,7 +428,8 @@ async def test_cancel_callback_success():
|
||||
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location'] == 'http://test.com/settings?checkout=cancel'
|
||||
response.headers['location']
|
||||
== 'http://test.com/settings/billing?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
@@ -435,314 +441,67 @@ async def test_cancel_callback_success():
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_payment_method_with_payment_method():
|
||||
"""Test has_payment_method returns True when user has a payment method."""
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[MagicMock()])),
|
||||
) as mock_list_payment_methods,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
MagicMock(stripe_customer_id='cus_test123')
|
||||
)
|
||||
|
||||
mock_has_payment_method = AsyncMock(return_value=True)
|
||||
with patch(
|
||||
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
|
||||
mock_has_payment_method,
|
||||
):
|
||||
result = await has_payment_method('mock_user')
|
||||
assert result is True
|
||||
mock_list_payment_methods.assert_called_once_with('cus_test123')
|
||||
mock_has_payment_method.assert_called_once_with('mock_user')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_payment_method_without_payment_method():
|
||||
"""Test has_payment_method returns False when user has no payment method."""
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[])),
|
||||
) as mock_list_payment_methods,
|
||||
mock_has_payment_method = AsyncMock(return_value=False)
|
||||
with patch(
|
||||
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
|
||||
mock_has_payment_method,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
MagicMock(stripe_customer_id='cus_test123')
|
||||
)
|
||||
|
||||
mock_has_payment_method.return_value = False
|
||||
result = await has_payment_method('mock_user')
|
||||
assert result is False
|
||||
mock_list_payment_methods.assert_called_once_with('cus_test123')
|
||||
mock_has_payment_method.assert_called_once_with('mock_user')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_success():
|
||||
"""Test successful subscription cancellation."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
async def test_create_customer_setup_session_success():
|
||||
"""Test successful creation of customer setup session."""
|
||||
mock_request = Request(
|
||||
scope={
|
||||
'type': 'http',
|
||||
'path': '/api/billing/create-customer-setup-session',
|
||||
'server': ('test.com', 80),
|
||||
'headers': [],
|
||||
}
|
||||
)
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
# Mock Stripe subscription response
|
||||
mock_stripe_subscription = MagicMock()
|
||||
mock_stripe_subscription.cancel_at_period_end = True
|
||||
mock_customer_info = {'customer_id': 'mock-customer-id', 'org_id': 'mock-org-id'}
|
||||
mock_session = MagicMock()
|
||||
mock_session.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_create = AsyncMock(return_value=mock_session)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Subscription.modify_async',
|
||||
AsyncMock(return_value=mock_stripe_subscription),
|
||||
) as mock_stripe_modify,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function
|
||||
result = await cancel_subscription('test_user')
|
||||
|
||||
# Verify Stripe API was called
|
||||
mock_stripe_modify.assert_called_once_with(
|
||||
'sub_test123', cancel_at_period_end=True
|
||||
)
|
||||
|
||||
# Verify database was updated
|
||||
assert mock_subscription_access.cancelled_at is not None
|
||||
mock_session.merge.assert_called_once_with(mock_subscription_access)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Verify response
|
||||
assert result.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_no_active_subscription():
|
||||
"""Test cancellation when no active subscription exists."""
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Setup mock session with no subscription found
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert 'No active subscription found' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_missing_stripe_id():
|
||||
"""Test cancellation when subscription has no Stripe ID."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock subscription without Stripe ID
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id=None, # Missing Stripe ID
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert 'missing Stripe subscription ID' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_stripe_error():
|
||||
"""Test cancellation when Stripe API fails."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Subscription.modify_async',
|
||||
AsyncMock(side_effect=stripe.StripeError('API Error')),
|
||||
'integrations.stripe_service.find_or_create_customer_by_user_id',
|
||||
AsyncMock(return_value=mock_customer_info),
|
||||
),
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to cancel subscription' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_duplicate_prevention(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test that creating a subscription when user already has active subscription raises error."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session to return existing active subscription
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
result = await create_customer_setup_session(mock_request, 'mock_user')
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert (
|
||||
'user already has an active subscription'
|
||||
in str(exc_info.value.detail).lower()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_allows_after_cancellation(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test that creating a subscription is allowed when previous subscription was cancelled."""
|
||||
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_session_obj.id = 'test_session_id'
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'integrations.stripe_service.find_or_create_customer',
|
||||
AsyncMock(return_value='cus_test123'),
|
||||
),
|
||||
patch(
|
||||
'stripe.checkout.Session.create_async',
|
||||
AsyncMock(return_value=mock_session_obj),
|
||||
),
|
||||
patch(
|
||||
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
|
||||
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session - the query should return None because cancelled subscriptions are filtered out
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Should succeed
|
||||
result = await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
assert isinstance(result, billing.CreateBillingSessionResponse)
|
||||
assert result.redirect_url == 'https://checkout.stripe.com/test-session'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_success_no_existing(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test successful subscription creation when no existing subscription."""
|
||||
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_session_obj.id = 'test_session_id'
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'integrations.stripe_service.find_or_create_customer',
|
||||
AsyncMock(return_value='cus_test123'),
|
||||
),
|
||||
patch(
|
||||
'stripe.checkout.Session.create_async',
|
||||
AsyncMock(return_value=mock_session_obj),
|
||||
),
|
||||
patch(
|
||||
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
|
||||
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session to return no existing subscription
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Should succeed
|
||||
result = await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
# Verify Stripe session creation parameters
|
||||
mock_create.assert_called_once_with(
|
||||
customer='mock-customer-id',
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url='http://test.com/?free_credits=success',
|
||||
cancel_url='http://test.com/',
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
assert result.redirect_url == 'https://checkout.stripe.com/test-session'
|
||||
|
||||
@@ -3,6 +3,7 @@ Tests for ConversationCallbackProcessor and ConversationCallback models.
|
||||
"""
|
||||
|
||||
import json
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from storage.conversation_callback import (
|
||||
@@ -11,6 +12,9 @@ from storage.conversation_callback import (
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
|
||||
@@ -80,15 +84,22 @@ class TestConversationCallback:
|
||||
"""Create a test conversation metadata record."""
|
||||
with session_maker() as session:
|
||||
metadata = StoredConversationMetadata(
|
||||
conversation_id='test_conversation_123', user_id='test_user_456'
|
||||
conversation_id='test_conversation_123'
|
||||
)
|
||||
metadata_saas = StoredConversationMetadataSaas(
|
||||
conversation_id='test_conversation_123',
|
||||
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
)
|
||||
session.add(metadata)
|
||||
session.add(metadata_saas)
|
||||
session.commit()
|
||||
session.refresh(metadata)
|
||||
yield metadata
|
||||
|
||||
# Cleanup
|
||||
session.delete(metadata)
|
||||
session.delete(metadata_saas)
|
||||
session.commit()
|
||||
|
||||
def test_callback_creation(self, conversation_metadata, session_maker):
|
||||
|
||||
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Unit tests for email validation dependency (get_admin_user_id).
|
||||
|
||||
Tests the FastAPI dependency that validates @openhands.dev email domain.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request
|
||||
from server.email_validation import get_admin_user_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock FastAPI request."""
|
||||
return MagicMock(spec=Request)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_auth():
|
||||
"""Create a mock user auth object."""
|
||||
mock_auth = AsyncMock()
|
||||
mock_auth.get_user_email = AsyncMock()
|
||||
return mock_auth
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_success(mock_request, mock_user_auth):
|
||||
"""
|
||||
GIVEN: Valid user ID and @openhands.dev email
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: User ID is returned successfully
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = 'test@openhands.dev'
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act
|
||||
result = await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
# Assert
|
||||
assert result == user_id
|
||||
mock_user_auth.get_user_email.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_no_user_id(mock_request):
|
||||
"""
|
||||
GIVEN: No user ID provided (None)
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 401 Unauthorized is raised
|
||||
"""
|
||||
# Arrange
|
||||
user_id = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'not authenticated' in exc_info.value.detail.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_no_email(mock_request, mock_user_auth):
|
||||
"""
|
||||
GIVEN: User ID provided but email is None
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 401 Unauthorized is raised
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = None
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'email not available' in exc_info.value.detail.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_invalid_domain(mock_request, mock_user_auth):
|
||||
"""
|
||||
GIVEN: User ID and email with non-@openhands.dev domain
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = 'test@external.com'
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'openhands.dev' in exc_info.value.detail.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_empty_string_user_id(mock_request):
|
||||
"""
|
||||
GIVEN: Empty string user ID
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 401 Unauthorized is raised
|
||||
"""
|
||||
# Arrange
|
||||
user_id = ''
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'not authenticated' in exc_info.value.detail.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_case_sensitivity(mock_request, mock_user_auth):
|
||||
"""
|
||||
GIVEN: Email with uppercase @OPENHANDS.DEV domain
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 403 Forbidden is raised (case-sensitive check)
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = 'test@OPENHANDS.DEV'
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_subdomain_not_allowed(
|
||||
mock_request, mock_user_auth
|
||||
):
|
||||
"""
|
||||
GIVEN: Email with subdomain like @test.openhands.dev
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = 'test@test.openhands.dev'
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_similar_domain_not_allowed(
|
||||
mock_request, mock_user_auth
|
||||
):
|
||||
"""
|
||||
GIVEN: Email with similar but different domain like @openhands.dev.fake.com
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = 'test@openhands.dev.fake.com'
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_logs_warning_on_invalid_domain(
|
||||
mock_request, mock_user_auth
|
||||
):
|
||||
"""
|
||||
GIVEN: User with invalid email domain
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: Warning is logged with user_id and email_domain
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
invalid_email = 'test@external.com'
|
||||
mock_user_auth.get_user_email.return_value = invalid_email
|
||||
|
||||
with (
|
||||
patch('server.email_validation.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.email_validation.logger') as mock_logger,
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException):
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
# Verify warning was logged
|
||||
mock_logger.warning.assert_called_once()
|
||||
call_args = mock_logger.warning.call_args
|
||||
assert 'Access denied' in call_args[0][0]
|
||||
assert call_args[1]['extra']['user_id'] == user_id
|
||||
assert call_args[1]['extra']['email_domain'] == 'external.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_with_plus_addressing(mock_request, mock_user_auth):
|
||||
"""
|
||||
GIVEN: Email with plus addressing (test+tag@openhands.dev)
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: User ID is returned successfully
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = 'test+tag@openhands.dev'
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act
|
||||
result = await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
# Assert
|
||||
assert result == user_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_with_dots_in_local_part(
|
||||
mock_request, mock_user_auth
|
||||
):
|
||||
"""
|
||||
GIVEN: Email with dots in local part (first.last@openhands.dev)
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: User ID is returned successfully
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = 'first.last@openhands.dev'
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act
|
||||
result = await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
# Assert
|
||||
assert result == user_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_openhands_user_id_empty_email(mock_request, mock_user_auth):
|
||||
"""
|
||||
GIVEN: Empty string email
|
||||
WHEN: get_admin_user_id is called
|
||||
THEN: 401 Unauthorized is raised
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
mock_user_auth.get_user_email.return_value = ''
|
||||
|
||||
with patch('server.email_validation.get_user_auth', return_value=mock_user_auth):
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_admin_user_id(mock_request, user_id)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'email not available' in exc_info.value.detail.lower()
|
||||
@@ -1,114 +1,100 @@
|
||||
"""Unit tests for get_user_v1_enabled_setting function."""
|
||||
"""Unit tests for get_user_v1_enabled_setting and is_v1_enabled_for_github_resolver functions."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.utils import get_user_v1_enabled_setting
|
||||
from integrations.github.github_view import (
|
||||
get_user_v1_enabled_setting,
|
||||
is_v1_enabled_for_github_resolver,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_settings():
|
||||
"""Create a mock user settings object."""
|
||||
settings = MagicMock()
|
||||
settings.v1_enabled = True # Default to True, can be overridden in tests
|
||||
return settings
|
||||
def mock_org():
|
||||
"""Create a mock org object."""
|
||||
org = MagicMock()
|
||||
org.v1_enabled = True # Default to True, can be overridden in tests
|
||||
return org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings_store():
|
||||
"""Create a mock settings store."""
|
||||
store = MagicMock()
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Create a mock config object."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker():
|
||||
"""Create a mock session maker."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(
|
||||
mock_settings_store, mock_config, mock_session_maker, mock_user_settings
|
||||
):
|
||||
def mock_dependencies(mock_org):
|
||||
"""Fixture that patches all the common dependencies."""
|
||||
# Patch at the source module since SaasSettingsStore is imported inside the function
|
||||
with patch(
|
||||
'storage.saas_settings_store.SaasSettingsStore',
|
||||
return_value=mock_settings_store,
|
||||
) as mock_store_class, patch(
|
||||
'integrations.utils.get_config', return_value=mock_config
|
||||
) as mock_get_config, patch(
|
||||
'integrations.utils.session_maker', mock_session_maker
|
||||
), patch(
|
||||
'integrations.utils.call_sync_from_async',
|
||||
return_value=mock_user_settings,
|
||||
) as mock_call_sync:
|
||||
return_value=mock_org,
|
||||
) as mock_call_sync, patch('integrations.utils.OrgStore') as mock_org_store:
|
||||
yield {
|
||||
'store_class': mock_store_class,
|
||||
'get_config': mock_get_config,
|
||||
'session_maker': mock_session_maker,
|
||||
'call_sync': mock_call_sync,
|
||||
'settings_store': mock_settings_store,
|
||||
'user_settings': mock_user_settings,
|
||||
'org_store': mock_org_store,
|
||||
'org': mock_org,
|
||||
}
|
||||
|
||||
|
||||
class TestGetUserV1EnabledSetting:
|
||||
"""Test cases for get_user_v1_enabled_setting function."""
|
||||
class TestIsV1EnabledForGithubResolver:
|
||||
"""Test cases for is_v1_enabled_for_github_resolver function.
|
||||
|
||||
This function returns True only if BOTH the environment variable
|
||||
ENABLE_V1_GITHUB_RESOLVER is true AND the user's org has v1_enabled=True.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'user_setting_enabled,expected_result',
|
||||
'env_var_enabled,user_setting_enabled,expected_result',
|
||||
[
|
||||
(True, True), # User enabled -> True
|
||||
(False, False), # User disabled -> False
|
||||
(False, True, False), # Env var disabled, user enabled -> False
|
||||
(True, False, False), # Env var enabled, user disabled -> False
|
||||
(True, True, True), # Both enabled -> True
|
||||
(False, False, False), # Both disabled -> False
|
||||
],
|
||||
)
|
||||
async def test_v1_enabled_user_setting(
|
||||
self, mock_dependencies, user_setting_enabled, expected_result
|
||||
async def test_v1_enabled_combinations(
|
||||
self, mock_dependencies, env_var_enabled, user_setting_enabled, expected_result
|
||||
):
|
||||
"""Test that the function returns the user's v1_enabled setting."""
|
||||
mock_dependencies['user_settings'].v1_enabled = user_setting_enabled
|
||||
"""Test all combinations of environment variable and user setting values."""
|
||||
mock_dependencies['org'].v1_enabled = user_setting_enabled
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_id')
|
||||
assert result is expected_result
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', env_var_enabled
|
||||
):
|
||||
result = await is_v1_enabled_for_github_resolver('test_user_id')
|
||||
assert result is expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_when_no_user_id(self):
|
||||
"""Test that the function returns False when no user_id is provided."""
|
||||
result = await get_user_v1_enabled_setting(None)
|
||||
assert result is False
|
||||
@pytest.mark.parametrize(
|
||||
'env_var_value,env_var_bool,expected_result',
|
||||
[
|
||||
('false', False, False), # Environment variable 'false' -> False
|
||||
('true', True, True), # Environment variable 'true' -> True
|
||||
],
|
||||
)
|
||||
async def test_environment_variable_integration(
|
||||
self, mock_dependencies, env_var_value, env_var_bool, expected_result
|
||||
):
|
||||
"""Test that the function properly reads the ENABLE_V1_GITHUB_RESOLVER environment variable."""
|
||||
mock_dependencies['org'].v1_enabled = True
|
||||
|
||||
result = await get_user_v1_enabled_setting('')
|
||||
assert result is False
|
||||
with patch.dict(
|
||||
os.environ, {'ENABLE_V1_GITHUB_RESOLVER': env_var_value}
|
||||
), patch('integrations.utils.os.getenv', return_value=env_var_value), patch(
|
||||
'integrations.github.github_view.ENABLE_V1_GITHUB_RESOLVER', env_var_bool
|
||||
):
|
||||
result = await is_v1_enabled_for_github_resolver('test_user_id')
|
||||
assert result is expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_when_settings_is_none(self, mock_dependencies):
|
||||
"""Test that the function returns False when settings is None."""
|
||||
mock_dependencies['call_sync'].return_value = None
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_id')
|
||||
assert result is False
|
||||
class TestGetUserV1EnabledSetting:
|
||||
"""Test cases for get_user_v1_enabled_setting function.
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_when_v1_enabled_is_none(self, mock_dependencies):
|
||||
"""Test that the function returns False when v1_enabled is None."""
|
||||
mock_dependencies['user_settings'].v1_enabled = None
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_id')
|
||||
assert result is False
|
||||
This function only returns the user's org v1_enabled setting.
|
||||
It does NOT check the ENABLE_V1_GITHUB_RESOLVER environment variable.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_calls_correct_methods(self, mock_dependencies):
|
||||
"""Test that the function calls the correct methods with correct parameters."""
|
||||
mock_dependencies['user_settings'].v1_enabled = True
|
||||
mock_dependencies['org'].v1_enabled = True
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
|
||||
@@ -116,13 +102,38 @@ class TestGetUserV1EnabledSetting:
|
||||
assert result is True
|
||||
|
||||
# Verify correct methods were called with correct parameters
|
||||
mock_dependencies['get_config'].assert_called_once()
|
||||
mock_dependencies['store_class'].assert_called_once_with(
|
||||
user_id='test_user_123',
|
||||
session_maker=mock_dependencies['session_maker'],
|
||||
config=mock_dependencies['get_config'].return_value,
|
||||
)
|
||||
mock_dependencies['call_sync'].assert_called_once_with(
|
||||
mock_dependencies['settings_store'].get_user_settings_by_keycloak_id,
|
||||
mock_dependencies['org_store'].get_current_org_from_keycloak_user_id,
|
||||
'test_user_123',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_user_setting_true(self, mock_dependencies):
|
||||
"""Test that the function returns True when org.v1_enabled is True."""
|
||||
mock_dependencies['org'].v1_enabled = True
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_user_setting_false(self, mock_dependencies):
|
||||
"""Test that the function returns False when org.v1_enabled is False."""
|
||||
mock_dependencies['org'].v1_enabled = False
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_org_returns_false(self, mock_dependencies):
|
||||
"""Test that the function returns False when no org is found."""
|
||||
# Mock call_sync_from_async to return None (no org found)
|
||||
mock_dependencies['call_sync'].return_value = None
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_org_v1_enabled_none_returns_false(self, mock_dependencies):
|
||||
"""Test that the function returns False when org.v1_enabled is None."""
|
||||
mock_dependencies['org'].v1_enabled = None
|
||||
|
||||
result = await get_user_v1_enabled_setting('test_user_123')
|
||||
assert result is False
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
Test that the models are correctly defined.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
def test_user_model(session_maker):
|
||||
"""Test that the User model works correctly."""
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test_org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
# Create a test user
|
||||
test_user_id = uuid4()
|
||||
user = User(id=test_user_id, current_org_id=org.id, language='en')
|
||||
session.add(user)
|
||||
session.flush()
|
||||
|
||||
# Create org_member relationship
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=1,
|
||||
llm_api_key='test-api-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
|
||||
# Query the user
|
||||
queried_user = session.query(User).filter(User.id == test_user_id).first()
|
||||
assert queried_user is not None
|
||||
assert queried_user.language == 'en'
|
||||
|
||||
# Query the org
|
||||
queried_org = session.query(Org).filter(Org.id == org.id).first()
|
||||
assert queried_org is not None
|
||||
assert queried_org.name == 'test_org'
|
||||
|
||||
# Query the org_member relationship
|
||||
queried_org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org.id, OrgMember.user_id == user.id)
|
||||
.first()
|
||||
)
|
||||
assert queried_org_member is not None
|
||||
assert queried_org_member.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
@@ -0,0 +1,253 @@
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
# Mock the database module before importing OrgMemberStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
|
||||
|
||||
def test_get_org_members(session_maker):
|
||||
# Test getting org_members by org ID
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user1 = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
user2 = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user1, user2, role])
|
||||
session.flush()
|
||||
|
||||
org_member1 = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user1.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-1',
|
||||
status='active',
|
||||
)
|
||||
org_member2 = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user2.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-2',
|
||||
status='active',
|
||||
)
|
||||
session.add_all([org_member1, org_member2])
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
org_members = OrgMemberStore.get_org_members(org_id)
|
||||
assert len(org_members) == 2
|
||||
api_keys = [om.llm_api_key.get_secret_value() for om in org_members]
|
||||
assert 'test-key-1' in api_keys
|
||||
assert 'test-key-2' in api_keys
|
||||
|
||||
|
||||
def test_get_user_orgs(session_maker):
|
||||
# Test getting org_members by user ID
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org1 = Org(name='test-org-1')
|
||||
org2 = Org(name='test-org-2')
|
||||
session.add_all([org1, org2])
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org1.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
org_member1 = OrgMember(
|
||||
org_id=org1.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-1',
|
||||
status='active',
|
||||
)
|
||||
org_member2 = OrgMember(
|
||||
org_id=org2.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-2',
|
||||
status='active',
|
||||
)
|
||||
session.add_all([org_member1, org_member2])
|
||||
session.commit()
|
||||
user_id = user.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
org_members = OrgMemberStore.get_user_orgs(user_id)
|
||||
assert len(org_members) == 2
|
||||
api_keys = [ou.llm_api_key.get_secret_value() for ou in org_members]
|
||||
assert 'test-key-1' in api_keys
|
||||
assert 'test-key-2' in api_keys
|
||||
|
||||
|
||||
def test_get_org_member(session_maker):
|
||||
# Test getting org_member by org and user ID
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
retrieved_org_member = OrgMemberStore.get_org_member(org_id, user_id)
|
||||
assert retrieved_org_member is not None
|
||||
assert retrieved_org_member.org_id == org_id
|
||||
assert retrieved_org_member.user_id == user_id
|
||||
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key'
|
||||
|
||||
|
||||
def test_add_user_to_org(session_maker):
|
||||
# Test adding a user to an org
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
role_id = role.id
|
||||
|
||||
# Test creation
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
org_member = OrgMemberStore.add_user_to_org(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
llm_api_key='new-test-key',
|
||||
status='active',
|
||||
)
|
||||
|
||||
assert org_member is not None
|
||||
assert org_member.org_id == org_id
|
||||
assert org_member.user_id == user_id
|
||||
assert org_member.role_id == role_id
|
||||
assert org_member.llm_api_key.get_secret_value() == 'new-test-key'
|
||||
assert org_member.status == 'active'
|
||||
|
||||
|
||||
def test_update_user_role_in_org(session_maker):
|
||||
# Test updating user role in org
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role1 = Role(name='admin', rank=1)
|
||||
role2 = Role(name='user', rank=2)
|
||||
session.add_all([user, role1, role2])
|
||||
session.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role1.id,
|
||||
llm_api_key='test-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
role2_id = role2.id
|
||||
|
||||
# Test update
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
updated_org_member = OrgMemberStore.update_user_role_in_org(
|
||||
org_id=org_id, user_id=user_id, role_id=role2_id, status='inactive'
|
||||
)
|
||||
|
||||
assert updated_org_member is not None
|
||||
assert updated_org_member.role_id == role2_id
|
||||
assert updated_org_member.status == 'inactive'
|
||||
|
||||
|
||||
def test_update_user_role_in_org_not_found(session_maker):
|
||||
# Test updating org_member that doesn't exist
|
||||
from uuid import uuid4
|
||||
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
updated_org_member = OrgMemberStore.update_user_role_in_org(
|
||||
org_id=uuid4(), user_id=99999, role_id=1
|
||||
)
|
||||
assert updated_org_member is None
|
||||
|
||||
|
||||
def test_remove_user_from_org(session_maker):
|
||||
# Test removing a user from an org
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
|
||||
# Test removal
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
result = OrgMemberStore.remove_user_from_org(org_id, user_id)
|
||||
assert result is True
|
||||
|
||||
# Verify it's removed
|
||||
retrieved_org_member = OrgMemberStore.get_org_member(org_id, user_id)
|
||||
assert retrieved_org_member is None
|
||||
|
||||
|
||||
def test_remove_user_from_org_not_found(session_maker):
|
||||
# Test removing user from org that doesn't exist
|
||||
from uuid import uuid4
|
||||
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
result = OrgMemberStore.remove_user_from_org(uuid4(), 99999)
|
||||
assert result is False
|
||||
@@ -0,0 +1,564 @@
|
||||
"""
|
||||
Unit tests for OrgService.
|
||||
|
||||
Tests the organization creation workflow with compensation pattern,
|
||||
including LiteLLM integration and cleanup on failures.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Mock the database module before importing OrgService
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from server.routes.org_models import (
|
||||
LiteLLMIntegrationError,
|
||||
OrgDatabaseError,
|
||||
OrgNameExistsError,
|
||||
)
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_service import OrgService
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_api():
|
||||
"""Mock LiteLLM API for testing."""
|
||||
api_key_patch = patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test_key')
|
||||
api_url_patch = patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'
|
||||
)
|
||||
team_id_patch = patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test_team')
|
||||
client_patch = patch('httpx.AsyncClient')
|
||||
|
||||
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = MagicMock(
|
||||
return_value={
|
||||
'team_id': 'test-team-id',
|
||||
'user_id': 'test-user-id',
|
||||
'key': 'test-api-key',
|
||||
}
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_response
|
||||
)
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def owner_role(session_maker):
|
||||
"""Create owner role in database."""
|
||||
with session_maker() as session:
|
||||
role = Role(id=1, name='owner', rank=1)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
return role
|
||||
|
||||
|
||||
def test_validate_name_uniqueness_with_unique_name(session_maker):
|
||||
"""
|
||||
GIVEN: A unique organization name
|
||||
WHEN: validate_name_uniqueness is called
|
||||
THEN: No exception is raised
|
||||
"""
|
||||
# Arrange
|
||||
unique_name = 'unique-org-name'
|
||||
|
||||
# Act & Assert - should not raise
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
OrgService.validate_name_uniqueness(unique_name)
|
||||
|
||||
|
||||
def test_validate_name_uniqueness_with_duplicate_name(session_maker):
|
||||
"""
|
||||
GIVEN: An organization name that already exists
|
||||
WHEN: validate_name_uniqueness is called
|
||||
THEN: OrgNameExistsError is raised
|
||||
"""
|
||||
# Arrange
|
||||
existing_name = 'existing-org'
|
||||
existing_org = Org(name=existing_name)
|
||||
|
||||
# Mock OrgStore.get_org_by_name to return the existing org
|
||||
with patch(
|
||||
'storage.org_service.OrgStore.get_org_by_name',
|
||||
return_value=existing_org,
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNameExistsError) as exc_info:
|
||||
OrgService.validate_name_uniqueness(existing_name)
|
||||
|
||||
assert existing_name in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_with_owner_success(
|
||||
session_maker, owner_role, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: Valid organization data and user ID
|
||||
WHEN: create_org_with_owner is called
|
||||
THEN: Organization and owner membership are created successfully
|
||||
"""
|
||||
# Arrange
|
||||
org_name = 'test-org'
|
||||
contact_name = 'John Doe'
|
||||
contact_email = 'john@example.com'
|
||||
user_id = uuid.uuid4()
|
||||
temp_org_id = uuid.uuid4()
|
||||
|
||||
# Create user in database first
|
||||
with session_maker() as session:
|
||||
user = User(id=user_id, current_org_id=temp_org_id)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
mock_settings = {'team_id': 'test-team', 'user_id': str(user_id)}
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_service.UserStore.create_default_settings',
|
||||
AsyncMock(return_value=mock_settings),
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_kwargs_from_settings',
|
||||
return_value={},
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgMemberStore.get_kwargs_from_settings',
|
||||
return_value={'llm_api_key': 'test-key'},
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService.create_org_with_owner(
|
||||
name=org_name,
|
||||
contact_name=contact_name,
|
||||
contact_email=contact_email,
|
||||
user_id=str(user_id),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.name == org_name
|
||||
assert result.contact_name == contact_name
|
||||
assert result.contact_email == contact_email
|
||||
assert result.org_version > 0 # Should be set to ORG_SETTINGS_VERSION
|
||||
assert result.default_llm_model is not None # Should be set
|
||||
|
||||
# Verify organization was persisted
|
||||
with session_maker() as session:
|
||||
persisted_org = session.get(Org, result.id)
|
||||
assert persisted_org is not None
|
||||
assert persisted_org.name == org_name
|
||||
|
||||
# Verify owner membership was created
|
||||
org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter_by(org_id=result.id, user_id=user_id)
|
||||
.first()
|
||||
)
|
||||
assert org_member is not None
|
||||
assert org_member.role_id == 1 # owner role id
|
||||
assert org_member.status == 'active'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_with_owner_duplicate_name(
|
||||
session_maker, owner_role, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: An organization name that already exists
|
||||
WHEN: create_org_with_owner is called
|
||||
THEN: OrgNameExistsError is raised without creating LiteLLM resources
|
||||
"""
|
||||
# Arrange
|
||||
existing_name = 'existing-org'
|
||||
with session_maker() as session:
|
||||
org = Org(name=existing_name)
|
||||
session.add(org)
|
||||
session.commit()
|
||||
|
||||
mock_create_settings = AsyncMock()
|
||||
|
||||
# Act & Assert
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_service.UserStore.create_default_settings',
|
||||
mock_create_settings,
|
||||
),
|
||||
):
|
||||
with pytest.raises(OrgNameExistsError):
|
||||
await OrgService.create_org_with_owner(
|
||||
name=existing_name,
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
user_id='test-user-123',
|
||||
)
|
||||
|
||||
# Verify no LiteLLM API calls were made (early exit)
|
||||
mock_create_settings.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_with_owner_litellm_failure(
|
||||
session_maker, owner_role, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: LiteLLM integration fails
|
||||
WHEN: create_org_with_owner is called
|
||||
THEN: LiteLLMIntegrationError is raised and no database records are created
|
||||
"""
|
||||
# Arrange
|
||||
org_name = 'test-org'
|
||||
|
||||
# Mock LiteLLM failure
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_service.UserStore.create_default_settings',
|
||||
AsyncMock(return_value=None),
|
||||
),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(LiteLLMIntegrationError):
|
||||
await OrgService.create_org_with_owner(
|
||||
name=org_name,
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
user_id='test-user-123',
|
||||
)
|
||||
|
||||
# Verify no organization was created in database
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter_by(name=org_name).first()
|
||||
assert org is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_with_owner_database_failure_triggers_cleanup(
|
||||
session_maker, owner_role, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: Database persistence fails after LiteLLM integration succeeds
|
||||
WHEN: create_org_with_owner is called
|
||||
THEN: OrgDatabaseError is raised and LiteLLM cleanup is triggered
|
||||
"""
|
||||
# Arrange
|
||||
org_name = 'test-org'
|
||||
user_id = str(uuid.uuid4())
|
||||
cleanup_called = False
|
||||
|
||||
def mock_cleanup(*args, **kwargs):
|
||||
nonlocal cleanup_called
|
||||
cleanup_called = True
|
||||
return None
|
||||
|
||||
mock_settings = {'team_id': 'test-team', 'user_id': user_id}
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_service.UserStore.create_default_settings',
|
||||
AsyncMock(return_value=mock_settings),
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_kwargs_from_settings',
|
||||
return_value={},
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgMemberStore.get_kwargs_from_settings',
|
||||
return_value={'llm_api_key': 'test-key'},
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.persist_org_with_owner',
|
||||
side_effect=Exception('Database connection failed'),
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgService._cleanup_litellm_resources',
|
||||
AsyncMock(side_effect=mock_cleanup),
|
||||
),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgDatabaseError) as exc_info:
|
||||
await OrgService.create_org_with_owner(
|
||||
name=org_name,
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Verify cleanup was called
|
||||
assert cleanup_called
|
||||
assert 'Database connection failed' in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_with_owner_entity_creation_failure_triggers_cleanup(
|
||||
session_maker, owner_role, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: Entity creation fails after LiteLLM integration succeeds
|
||||
WHEN: create_org_with_owner is called
|
||||
THEN: OrgDatabaseError is raised and LiteLLM cleanup is triggered
|
||||
"""
|
||||
# Arrange
|
||||
org_name = 'test-org'
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
mock_settings = {'team_id': 'test-team', 'user_id': user_id}
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_service.UserStore.create_default_settings',
|
||||
AsyncMock(return_value=mock_settings),
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_kwargs_from_settings',
|
||||
return_value={},
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgMemberStore.get_kwargs_from_settings',
|
||||
return_value={'llm_api_key': 'test-key'},
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgService.get_owner_role',
|
||||
side_effect=Exception('Owner role not found'),
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.LiteLlmManager.delete_team',
|
||||
AsyncMock(),
|
||||
) as mock_delete,
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgDatabaseError) as exc_info:
|
||||
await OrgService.create_org_with_owner(
|
||||
name=org_name,
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Verify cleanup was called
|
||||
mock_delete.assert_called_once()
|
||||
assert 'Owner role not found' in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_litellm_resources_success(mock_litellm_api):
|
||||
"""
|
||||
GIVEN: Valid org_id and user_id
|
||||
WHEN: _cleanup_litellm_resources is called
|
||||
THEN: LiteLLM team is deleted successfully and None is returned
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = 'test-user-123'
|
||||
|
||||
with patch(
|
||||
'storage.org_service.LiteLlmManager.delete_team',
|
||||
AsyncMock(),
|
||||
) as mock_delete:
|
||||
# Act
|
||||
result = await OrgService._cleanup_litellm_resources(org_id, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_delete.assert_called_once_with(str(org_id))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_litellm_resources_failure_returns_exception(mock_litellm_api):
|
||||
"""
|
||||
GIVEN: LiteLLM delete_team fails
|
||||
WHEN: _cleanup_litellm_resources is called
|
||||
THEN: Exception is returned (not raised) for logging
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = 'test-user-123'
|
||||
expected_error = Exception('LiteLLM API unavailable')
|
||||
|
||||
with patch(
|
||||
'storage.org_service.LiteLlmManager.delete_team',
|
||||
AsyncMock(side_effect=expected_error),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService._cleanup_litellm_resources(org_id, user_id)
|
||||
|
||||
# Assert
|
||||
assert result is expected_error
|
||||
assert 'LiteLLM API unavailable' in str(result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_failure_with_cleanup_success():
|
||||
"""
|
||||
GIVEN: Original error and successful cleanup
|
||||
WHEN: _handle_failure_with_cleanup is called
|
||||
THEN: OrgDatabaseError is raised with original error message
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = 'test-user-123'
|
||||
original_error = Exception('Database write failed')
|
||||
|
||||
with patch(
|
||||
'storage.org_service.OrgService._cleanup_litellm_resources',
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgDatabaseError) as exc_info:
|
||||
await OrgService._handle_failure_with_cleanup(
|
||||
org_id, user_id, original_error, 'Failed to create organization'
|
||||
)
|
||||
|
||||
assert 'Database write failed' in str(exc_info.value)
|
||||
assert 'Cleanup also failed' not in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_failure_with_cleanup_both_fail():
|
||||
"""
|
||||
GIVEN: Original error and cleanup also fails
|
||||
WHEN: _handle_failure_with_cleanup is called
|
||||
THEN: OrgDatabaseError is raised with both error messages
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = 'test-user-123'
|
||||
original_error = Exception('Database write failed')
|
||||
cleanup_error = Exception('LiteLLM API unavailable')
|
||||
|
||||
with patch(
|
||||
'storage.org_service.OrgService._cleanup_litellm_resources',
|
||||
AsyncMock(return_value=cleanup_error),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgDatabaseError) as exc_info:
|
||||
await OrgService._handle_failure_with_cleanup(
|
||||
org_id, user_id, original_error, 'Failed to create organization'
|
||||
)
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert 'Database write failed' in error_message
|
||||
assert 'Cleanup also failed' in error_message
|
||||
assert 'LiteLLM API unavailable' in error_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_credits_success(mock_litellm_api):
|
||||
"""
|
||||
GIVEN: Valid user_id and org_id with LiteLLM team info
|
||||
WHEN: get_org_credits is called
|
||||
THEN: Credits are calculated correctly (max_budget - spend)
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
org_id = uuid.uuid4()
|
||||
max_budget = 100.0
|
||||
spend = 25.0
|
||||
|
||||
mock_team_info = {
|
||||
'litellm_budget_table': {'max_budget': max_budget},
|
||||
'spend': spend,
|
||||
}
|
||||
|
||||
with patch(
|
||||
'storage.org_service.LiteLlmManager.get_user_team_info',
|
||||
AsyncMock(return_value=mock_team_info),
|
||||
):
|
||||
# Act
|
||||
credits = await OrgService.get_org_credits(user_id, org_id)
|
||||
|
||||
# Assert
|
||||
assert credits == 75.0 # 100 - 25
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_credits_no_team_info(mock_litellm_api):
|
||||
"""
|
||||
GIVEN: LiteLLM returns no team info
|
||||
WHEN: get_org_credits is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
'storage.org_service.LiteLlmManager.get_user_team_info',
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
# Act
|
||||
credits = await OrgService.get_org_credits(user_id, org_id)
|
||||
|
||||
# Assert
|
||||
assert credits is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_credits_negative_credits_returns_zero(mock_litellm_api):
|
||||
"""
|
||||
GIVEN: Spend exceeds max_budget
|
||||
WHEN: get_org_credits is called
|
||||
THEN: Zero credits are returned (not negative)
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
org_id = uuid.uuid4()
|
||||
max_budget = 100.0
|
||||
spend = 150.0 # Over budget
|
||||
|
||||
mock_team_info = {
|
||||
'litellm_budget_table': {'max_budget': max_budget},
|
||||
'spend': spend,
|
||||
}
|
||||
|
||||
with patch(
|
||||
'storage.org_service.LiteLlmManager.get_user_team_info',
|
||||
AsyncMock(return_value=mock_team_info),
|
||||
):
|
||||
# Act
|
||||
credits = await OrgService.get_org_credits(user_id, org_id)
|
||||
|
||||
# Assert
|
||||
assert credits == 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_credits_api_failure_returns_none(mock_litellm_api):
|
||||
"""
|
||||
GIVEN: LiteLLM API call fails
|
||||
WHEN: get_org_credits is called
|
||||
THEN: None is returned and error is logged
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
'storage.org_service.LiteLlmManager.get_user_team_info',
|
||||
AsyncMock(side_effect=Exception('API error')),
|
||||
):
|
||||
# Act
|
||||
credits = await OrgService.get_org_credits(user_id, org_id)
|
||||
|
||||
# Assert
|
||||
assert credits is None
|
||||
@@ -0,0 +1,417 @@
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
# Mock the database module before importing OrgStore
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_store import OrgStore
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_api():
|
||||
api_key_patch = patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test_key')
|
||||
api_url_patch = patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'
|
||||
)
|
||||
team_id_patch = patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test_team')
|
||||
client_patch = patch('httpx.AsyncClient')
|
||||
|
||||
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.json = MagicMock(return_value={'key': 'test_api_key'})
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.patch.return_value = (
|
||||
mock_response
|
||||
)
|
||||
yield mock_client
|
||||
|
||||
|
||||
def test_get_org_by_id(session_maker, mock_litellm_api):
|
||||
# Test getting org by ID
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_org_by_id(org_id)
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.id == org_id
|
||||
assert retrieved_org.name == 'test-org'
|
||||
|
||||
|
||||
def test_get_org_by_id_not_found(session_maker):
|
||||
# Test getting org by ID when it doesn't exist
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
non_existent_id = uuid.uuid4()
|
||||
retrieved_org = OrgStore.get_org_by_id(non_existent_id)
|
||||
assert retrieved_org is None
|
||||
|
||||
|
||||
def test_list_orgs(session_maker, mock_litellm_api):
|
||||
# Test listing all orgs
|
||||
with session_maker() as session:
|
||||
# Create test orgs
|
||||
org1 = Org(name='test-org-1')
|
||||
org2 = Org(name='test-org-2')
|
||||
session.add_all([org1, org2])
|
||||
session.commit()
|
||||
|
||||
# Test listing
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
orgs = OrgStore.list_orgs()
|
||||
assert len(orgs) >= 2
|
||||
org_names = [org.name for org in orgs]
|
||||
assert 'test-org-1' in org_names
|
||||
assert 'test-org-2' in org_names
|
||||
|
||||
|
||||
def test_update_org(session_maker, mock_litellm_api):
|
||||
# Test updating org details
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org', agent='CodeActAgent')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Test update
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
updated_org = OrgStore.update_org(
|
||||
org_id=org_id, kwargs={'name': 'updated-org', 'agent': 'PlannerAgent'}
|
||||
)
|
||||
|
||||
assert updated_org is not None
|
||||
assert updated_org.name == 'updated-org'
|
||||
assert updated_org.agent == 'PlannerAgent'
|
||||
|
||||
|
||||
def test_update_org_not_found(session_maker):
|
||||
# Test updating org that doesn't exist
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
from uuid import uuid4
|
||||
|
||||
updated_org = OrgStore.update_org(
|
||||
org_id=uuid4(), kwargs={'name': 'updated-org'}
|
||||
)
|
||||
assert updated_org is None
|
||||
|
||||
|
||||
def test_create_org(session_maker, mock_litellm_api):
|
||||
# Test creating a new org
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
org = OrgStore.create_org(kwargs={'name': 'new-org', 'agent': 'CodeActAgent'})
|
||||
|
||||
assert org is not None
|
||||
assert org.name == 'new-org'
|
||||
assert org.agent == 'CodeActAgent'
|
||||
assert org.id is not None
|
||||
|
||||
|
||||
def test_get_org_by_name(session_maker, mock_litellm_api):
|
||||
# Test getting org by name
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org-by-name')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_org_by_name('test-org-by-name')
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.name == 'test-org-by-name'
|
||||
|
||||
|
||||
def test_get_current_org_from_keycloak_user_id(session_maker, mock_litellm_api):
|
||||
# Test getting current org from user ID
|
||||
test_user_id = uuid.uuid4()
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
from storage.user import User
|
||||
|
||||
user = User(id=test_user_id, current_org_id=org.id)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_current_org_from_keycloak_user_id(
|
||||
str(test_user_id)
|
||||
)
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.name == 'test-org'
|
||||
|
||||
|
||||
def test_get_kwargs_from_settings():
|
||||
# Test extracting org kwargs from settings
|
||||
settings = Settings(
|
||||
language='es',
|
||||
agent='CodeActAgent',
|
||||
llm_model='gpt-4',
|
||||
llm_api_key=SecretStr('test-key'),
|
||||
enable_sound_notifications=True,
|
||||
)
|
||||
|
||||
kwargs = OrgStore.get_kwargs_from_settings(settings)
|
||||
|
||||
# Should only include fields that exist in Org model
|
||||
assert 'agent' in kwargs
|
||||
assert 'default_llm_model' in kwargs
|
||||
assert kwargs['agent'] == 'CodeActAgent'
|
||||
assert kwargs['default_llm_model'] == 'gpt-4'
|
||||
# Should not include fields that don't exist in Org model
|
||||
assert 'language' not in kwargs # language is not in Org model
|
||||
assert 'llm_api_key' not in kwargs
|
||||
assert 'llm_model' not in kwargs
|
||||
assert 'enable_sound_notifications' not in kwargs
|
||||
|
||||
|
||||
def test_persist_org_with_owner_success(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: Valid org and org_member entities
|
||||
WHEN: persist_org_with_owner is called
|
||||
THEN: Both entities are persisted in a single transaction and org is returned
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
# Create user and role first
|
||||
with session_maker() as session:
|
||||
user = User(id=user_id, current_org_id=org_id)
|
||||
role = Role(id=1, name='owner', rank=1)
|
||||
session.add(user)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
|
||||
org = Org(
|
||||
id=org_id,
|
||||
name='Test Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
)
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=1,
|
||||
status='active',
|
||||
llm_api_key='test-api-key-123',
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
result = OrgStore.persist_org_with_owner(org, org_member)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.id == org_id
|
||||
assert result.name == 'Test Organization'
|
||||
|
||||
# Verify both entities were persisted
|
||||
with session_maker() as session:
|
||||
persisted_org = session.get(Org, org_id)
|
||||
assert persisted_org is not None
|
||||
assert persisted_org.name == 'Test Organization'
|
||||
|
||||
persisted_member = (
|
||||
session.query(OrgMember).filter_by(org_id=org_id, user_id=user_id).first()
|
||||
)
|
||||
assert persisted_member is not None
|
||||
assert persisted_member.status == 'active'
|
||||
assert persisted_member.role_id == 1
|
||||
|
||||
|
||||
def test_persist_org_with_owner_returns_refreshed_org(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: Valid org and org_member entities
|
||||
WHEN: persist_org_with_owner is called
|
||||
THEN: The returned org is refreshed from database with all fields populated
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
user = User(id=user_id, current_org_id=org_id)
|
||||
role = Role(id=1, name='owner', rank=1)
|
||||
session.add(user)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
|
||||
org = Org(
|
||||
id=org_id,
|
||||
name='Test Org',
|
||||
contact_name='Jane Doe',
|
||||
contact_email='jane@example.com',
|
||||
agent='CodeActAgent',
|
||||
)
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=1,
|
||||
status='active',
|
||||
llm_api_key='test-key',
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
result = OrgStore.persist_org_with_owner(org, org_member)
|
||||
|
||||
# Assert - verify the returned object has database-generated fields
|
||||
assert result.id == org_id
|
||||
assert result.name == 'Test Org'
|
||||
assert result.agent == 'CodeActAgent'
|
||||
# Verify org_version was set by create_org logic (if applicable)
|
||||
assert hasattr(result, 'org_version')
|
||||
|
||||
|
||||
def test_persist_org_with_owner_transaction_atomicity(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: Valid org but invalid org_member (missing required field)
|
||||
WHEN: persist_org_with_owner is called
|
||||
THEN: Transaction fails and neither entity is persisted
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
user = User(id=user_id, current_org_id=org_id)
|
||||
role = Role(id=1, name='owner', rank=1)
|
||||
session.add(user)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
|
||||
org = Org(
|
||||
id=org_id,
|
||||
name='Test Org',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
)
|
||||
|
||||
# Create invalid org_member (missing required llm_api_key field)
|
||||
org_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=1,
|
||||
status='active',
|
||||
# llm_api_key is missing - should cause NOT NULL constraint violation
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
with pytest.raises(IntegrityError): # NOT NULL constraint violation
|
||||
OrgStore.persist_org_with_owner(org, org_member)
|
||||
|
||||
# Verify neither entity was persisted (transaction rolled back)
|
||||
with session_maker() as session:
|
||||
persisted_org = session.get(Org, org_id)
|
||||
assert persisted_org is None
|
||||
|
||||
persisted_member = (
|
||||
session.query(OrgMember).filter_by(org_id=org_id, user_id=user_id).first()
|
||||
)
|
||||
assert persisted_member is None
|
||||
|
||||
|
||||
def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: Org with multiple optional fields populated
|
||||
WHEN: persist_org_with_owner is called
|
||||
THEN: All fields are persisted correctly
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
user = User(id=user_id, current_org_id=org_id)
|
||||
role = Role(id=1, name='owner', rank=1)
|
||||
session.add(user)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
|
||||
org = Org(
|
||||
id=org_id,
|
||||
name='Complex Org',
|
||||
contact_name='Alice Smith',
|
||||
contact_email='alice@example.com',
|
||||
agent='CodeActAgent',
|
||||
default_max_iterations=50,
|
||||
confirmation_mode=True,
|
||||
billing_margin=0.15,
|
||||
)
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=1,
|
||||
status='active',
|
||||
llm_api_key='test-key',
|
||||
max_iterations=100,
|
||||
llm_model='gpt-4',
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
result = OrgStore.persist_org_with_owner(org, org_member)
|
||||
|
||||
# Assert
|
||||
assert result.name == 'Complex Org'
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.default_max_iterations == 50
|
||||
assert result.confirmation_mode is True
|
||||
assert result.billing_margin == 0.15
|
||||
|
||||
# Verify persistence
|
||||
with session_maker() as session:
|
||||
persisted_org = session.get(Org, org_id)
|
||||
assert persisted_org.agent == 'CodeActAgent'
|
||||
assert persisted_org.default_max_iterations == 50
|
||||
assert persisted_org.confirmation_mode is True
|
||||
assert persisted_org.billing_margin == 0.15
|
||||
|
||||
persisted_member = (
|
||||
session.query(OrgMember).filter_by(org_id=org_id, user_id=user_id).first()
|
||||
)
|
||||
assert persisted_member.max_iterations == 100
|
||||
assert persisted_member.llm_model == 'gpt-4'
|
||||
@@ -1,32 +1,15 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.github.github_view import get_user_proactive_conversation_setting
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from integrations.github.github_view import get_user_proactive_conversation_setting
|
||||
from storage.org import Org
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
# Mock the call_sync_from_async function to return the result of the function directly
|
||||
def mock_call_sync_from_async(func, *args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
filter = MagicMock()
|
||||
|
||||
# Mock the context manager behavior
|
||||
session.__enter__.return_value = session
|
||||
|
||||
session.query.return_value = query
|
||||
query.filter.return_value = filter
|
||||
|
||||
return session, query, filter
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_no_user_id():
|
||||
"""Test that the function returns False when no user ID is provided."""
|
||||
with patch(
|
||||
@@ -42,75 +25,82 @@ async def test_get_user_proactive_conversation_setting_no_user_id():
|
||||
assert await get_user_proactive_conversation_setting(None) is False
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_not_found(mock_session):
|
||||
async def test_get_user_proactive_conversation_setting_user_not_found():
|
||||
"""Test that False is returned when the user is not found."""
|
||||
session, query, filter = mock_session
|
||||
filter.first.return_value = None
|
||||
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=None,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is False
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_none(mock_session):
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_none():
|
||||
"""Test that False is returned when the user setting is None."""
|
||||
session, query, filter = mock_session
|
||||
user_settings = MagicMock(spec=UserSettings)
|
||||
user_settings.enable_proactive_conversation_starters = None
|
||||
filter.first.return_value = user_settings
|
||||
mock_org = MagicMock(spec=Org)
|
||||
mock_org.enable_proactive_conversation_starters = None
|
||||
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is False
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_true(mock_session):
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_true():
|
||||
"""Test that True is returned when the user setting is True and the global setting is True."""
|
||||
session, query, filter = mock_session
|
||||
user_settings = MagicMock(spec=UserSettings)
|
||||
user_settings.enable_proactive_conversation_starters = True
|
||||
filter.first.return_value = user_settings
|
||||
mock_org = MagicMock(spec=Org)
|
||||
mock_org.enable_proactive_conversation_starters = True
|
||||
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is True
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_false(mock_session):
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_false():
|
||||
"""Test that False is returned when the user setting is False, regardless of global setting."""
|
||||
session, query, filter = mock_session
|
||||
user_settings = MagicMock(spec=UserSettings)
|
||||
user_settings.enable_proactive_conversation_starters = False
|
||||
filter.first.return_value = user_settings
|
||||
mock_org = MagicMock(spec=Org)
|
||||
mock_org.enable_proactive_conversation_starters = False
|
||||
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is False
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
# Mock the database module before importing RoleStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
|
||||
def test_get_role_by_id(session_maker):
|
||||
# Test getting role by ID
|
||||
with session_maker() as session:
|
||||
# Create a test role
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
role_id = role.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_id(role_id)
|
||||
assert retrieved_role is not None
|
||||
assert retrieved_role.id == role_id
|
||||
assert retrieved_role.name == 'admin'
|
||||
|
||||
|
||||
def test_get_role_by_id_not_found(session_maker):
|
||||
# Test getting role by ID when it doesn't exist
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_id(99999)
|
||||
assert retrieved_role is None
|
||||
|
||||
|
||||
def test_get_role_by_name(session_maker):
|
||||
# Test getting role by name
|
||||
with session_maker() as session:
|
||||
# Create a test role
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
role_id = role.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_name('admin')
|
||||
assert retrieved_role is not None
|
||||
assert retrieved_role.id == role_id
|
||||
assert retrieved_role.name == 'admin'
|
||||
|
||||
|
||||
def test_get_role_by_name_not_found(session_maker):
|
||||
# Test getting role by name when it doesn't exist
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_name('nonexistent')
|
||||
assert retrieved_role is None
|
||||
|
||||
|
||||
def test_list_roles(session_maker):
|
||||
# Test listing all roles
|
||||
with session_maker() as session:
|
||||
# Create test roles
|
||||
role1 = Role(name='admin', rank=1)
|
||||
role2 = Role(name='user', rank=2)
|
||||
session.add_all([role1, role2])
|
||||
session.commit()
|
||||
|
||||
# Test listing
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
roles = RoleStore.list_roles()
|
||||
assert len(roles) >= 2
|
||||
role_names = [role.name for role in roles]
|
||||
assert 'admin' in role_names
|
||||
assert 'user' in role_names
|
||||
|
||||
|
||||
def test_create_role(session_maker):
|
||||
# Test creating a new role
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
role = RoleStore.create_role(name='moderator', rank=2)
|
||||
|
||||
assert role is not None
|
||||
assert role.name == 'moderator'
|
||||
assert role.rank == 2
|
||||
assert role.id is not None
|
||||
@@ -1,11 +1,16 @@
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_call_sync_from_async():
|
||||
@@ -20,12 +25,22 @@ def mock_call_sync_from_async():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_user_store():
|
||||
"""Mock UserStore.get_user_by_id to return a mock user"""
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.current_org_id = UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
with patch('storage.user_store.UserStore.get_user_by_id', return_value=mock_user):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_and_get(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
metadata = ConversationMetadata(
|
||||
conversation_id='my-conversation-id',
|
||||
user_id='12345',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
selected_repository='my-repo',
|
||||
selected_branch=None,
|
||||
created_at=datetime.now(UTC),
|
||||
@@ -47,13 +62,13 @@ async def test_save_and_get(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
|
||||
# Create test conversations with different timestamps
|
||||
conversations = [
|
||||
ConversationMetadata(
|
||||
conversation_id=f'conv-{i}',
|
||||
user_id='12345',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
selected_repository='repo',
|
||||
selected_branch=None,
|
||||
created_at=datetime(2024, 1, i + 1, tzinfo=UTC),
|
||||
@@ -92,10 +107,10 @@ async def test_search(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_metadata(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
metadata = ConversationMetadata(
|
||||
conversation_id='to-delete',
|
||||
user_id='12345',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
selected_repository='repo',
|
||||
selected_branch=None,
|
||||
created_at=datetime.now(UTC),
|
||||
@@ -112,17 +127,17 @@ async def test_delete_metadata(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_metadata(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await store.get_metadata('nonexistent-id')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
metadata = ConversationMetadata(
|
||||
conversation_id='exists-test',
|
||||
user_id='12345',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
selected_repository='repo',
|
||||
selected_branch='test-branch',
|
||||
created_at=datetime.now(UTC),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
@@ -19,6 +20,14 @@ def mock_config():
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Mock user with org_id."""
|
||||
user = MagicMock()
|
||||
user.current_org_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def secrets_store(session_maker, mock_config):
|
||||
return SaasSecretsStore('user-id', session_maker, mock_config)
|
||||
@@ -26,7 +35,14 @@ def secrets_store(session_maker, mock_config):
|
||||
|
||||
class TestSaasSecretsStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_and_load(self, secrets_store):
|
||||
@patch(
|
||||
'storage.saas_secrets_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
async def test_store_and_load(self, mock_get_user, secrets_store, mock_user):
|
||||
# Setup mock
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Create a Secrets object with some test data
|
||||
user_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
@@ -59,7 +75,13 @@ class TestSaasSecretsStore:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_encryption_decryption(self, secrets_store):
|
||||
@patch(
|
||||
'storage.saas_secrets_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
async def test_encryption_decryption(self, mock_get_user, secrets_store, mock_user):
|
||||
# Setup mock
|
||||
mock_get_user.return_value = mock_user
|
||||
# Create a Secrets object with sensitive data
|
||||
user_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
@@ -89,6 +111,7 @@ class TestSaasSecretsStore:
|
||||
stored = (
|
||||
session.query(StoredCustomSecrets)
|
||||
.filter(StoredCustomSecrets.keycloak_user_id == 'user-id')
|
||||
.filter(StoredCustomSecrets.org_id == mock_user.current_org_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -152,7 +175,15 @@ class TestSaasSecretsStore:
|
||||
assert await secrets_store.load() is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_existing_secrets(self, secrets_store):
|
||||
@patch(
|
||||
'storage.saas_secrets_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
async def test_update_existing_secrets(
|
||||
self, mock_get_user, secrets_store, mock_user
|
||||
):
|
||||
# Setup mock
|
||||
mock_get_user.return_value = mock_user
|
||||
# Create and store initial secrets
|
||||
initial_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user