mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4d97ae6c5d | |||
| eeb81ecc49 | |||
| 4d0f2e7a6d | |||
| 2c6d1e97e8 | |||
| 180557265f |
+1
-1
@@ -2,7 +2,7 @@ BACKEND_HOST ?= "127.0.0.1"
|
||||
BACKEND_PORT = 3000
|
||||
BACKEND_HOST_PORT = "$(BACKEND_HOST):$(BACKEND_PORT)"
|
||||
FRONTEND_PORT = 3001
|
||||
OPENHANDS_PATH ?= ".."
|
||||
OPENHANDS_PATH ?= "../../OpenHands"
|
||||
OPENHANDS := $(OPENHANDS_PATH)
|
||||
OPENHANDS_FRONTEND_PATH = $(OPENHANDS)/frontend/build
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ from integrations.utils import (
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
)
|
||||
from integrations.v1_utils import get_saas_user_auth
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import SecretStr
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
@@ -165,13 +164,8 @@ class GithubManager(Manager):
|
||||
)
|
||||
|
||||
if await self.is_job_requested(message):
|
||||
payload = message.message.get('payload', {})
|
||||
user_id = payload['sender']['id']
|
||||
keycloak_user_id = await self.token_manager.get_user_id_from_idp_user_id(
|
||||
user_id, ProviderType.GITHUB
|
||||
)
|
||||
github_view = await GithubFactory.create_github_view_from_payload(
|
||||
message, keycloak_user_id
|
||||
message, self.token_manager
|
||||
)
|
||||
logger.info(
|
||||
f'[GitHub] Creating job for {github_view.user_info.username} in {github_view.full_repo_name}#{github_view.issue_number}'
|
||||
@@ -288,15 +282,8 @@ class GithubManager(Manager):
|
||||
f'[Github]: Error summarizing issue solvability: {str(e)}'
|
||||
)
|
||||
|
||||
saas_user_auth = await get_saas_user_auth(
|
||||
github_view.user_info.keycloak_user_id, self.token_manager
|
||||
)
|
||||
|
||||
await github_view.create_new_conversation(
|
||||
self.jinja_env,
|
||||
secret_store.provider_tokens,
|
||||
convo_metadata,
|
||||
saas_user_auth,
|
||||
self.jinja_env, secret_store.provider_tokens, convo_metadata
|
||||
)
|
||||
|
||||
conversation_id = github_view.conversation_id
|
||||
@@ -305,7 +292,14 @@ class GithubManager(Manager):
|
||||
f'[GitHub] Created conversation {conversation_id} for user {user_info.username}'
|
||||
)
|
||||
|
||||
if not github_view.v1:
|
||||
from openhands.server.shared import ConversationStoreImpl, config
|
||||
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, github_view.user_info.keycloak_user_id
|
||||
)
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
|
||||
if metadata.conversation_version != 'v1':
|
||||
# Create a GithubCallbackProcessor
|
||||
processor = GithubCallbackProcessor(
|
||||
github_view=github_view,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from github import Github, GithubIntegration
|
||||
@@ -9,7 +8,6 @@ from integrations.github.github_types import (
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
from integrations.models import Message
|
||||
from integrations.resolver_context import ResolverUserContext
|
||||
from integrations.types import ResolverViewInterface, UserData
|
||||
from integrations.utils import (
|
||||
ENABLE_PROACTIVE_CONVERSATION_STARTERS,
|
||||
@@ -19,13 +17,14 @@ from integrations.utils import (
|
||||
has_exact_mention,
|
||||
)
|
||||
from jinja2 import Environment
|
||||
from pydantic.dataclasses import dataclass
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.org_store import OrgStore
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.agent_server.models import SendMessageRequest
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
@@ -35,16 +34,18 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
from openhands.app_server.config import get_app_conversation_service
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.app_server.user.user_models import UserInfo
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
from openhands.integrations.service_types import Comment
|
||||
from openhands.sdk import TextContent
|
||||
from openhands.sdk.conversation.secret_source import SecretSource
|
||||
from openhands.server.services.conversation_service import (
|
||||
initialize_conversation,
|
||||
start_conversation,
|
||||
)
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
@@ -54,6 +55,52 @@ from openhands.utils.async_utils import call_sync_from_async
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
|
||||
|
||||
class GithubUserContext(UserContext):
|
||||
"""User context for GitHub integration that provides user info without web request."""
|
||||
|
||||
def __init__(self, keycloak_user_id: str, git_provider_tokens: PROVIDER_TOKEN_TYPE):
|
||||
self.keycloak_user_id = keycloak_user_id
|
||||
self.git_provider_tokens = git_provider_tokens
|
||||
self.settings_store = SaasSettingsStore(
|
||||
user_id=self.keycloak_user_id,
|
||||
session_maker=session_maker,
|
||||
config=get_config(),
|
||||
)
|
||||
|
||||
self.secrets_store = SaasSecretsStore(
|
||||
self.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return self.keycloak_user_id
|
||||
|
||||
async def get_user_info(self) -> UserInfo:
|
||||
user_settings = await self.settings_store.load()
|
||||
return UserInfo(
|
||||
id=self.keycloak_user_id,
|
||||
**user_settings.model_dump(context={'expose_secrets': True}),
|
||||
)
|
||||
|
||||
async def get_authenticated_git_url(self, repository: str) -> str:
|
||||
# This would need to be implemented based on the git provider tokens
|
||||
# For now, return a basic HTTPS URL
|
||||
return f'https://github.com/{repository}.git'
|
||||
|
||||
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
|
||||
# Return the appropriate token from git_provider_tokens
|
||||
if provider_type == ProviderType.GITHUB and self.git_provider_tokens:
|
||||
return self.git_provider_tokens.get(ProviderType.GITHUB)
|
||||
return None
|
||||
|
||||
async def get_secrets(self) -> dict[str, SecretSource]:
|
||||
# Return empty dict for now - GitHub integration handles secrets separately
|
||||
user_secrets = await self.secrets_store.load()
|
||||
return dict(user_secrets.custom_secrets) if user_secrets else {}
|
||||
|
||||
async def get_mcp_api_key(self) -> str | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
"""Get the user's proactive conversation setting.
|
||||
|
||||
@@ -72,20 +119,22 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
# Check global setting first - if disabled globally, return False
|
||||
if not ENABLE_PROACTIVE_CONVERSATION_STARTERS:
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
|
||||
if not settings or settings.enable_proactive_conversation_starters is None:
|
||||
return False
|
||||
|
||||
def _get_setting():
|
||||
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
return False
|
||||
return bool(org.enable_proactive_conversation_starters)
|
||||
|
||||
return await call_sync_from_async(_get_setting)
|
||||
return settings.enable_proactive_conversation_starters
|
||||
|
||||
|
||||
async def get_user_v1_enabled_setting(user_id: str) -> bool:
|
||||
async def get_user_v1_enabled_setting(user_id: str | None) -> bool:
|
||||
"""Get the user's V1 conversation API setting.
|
||||
|
||||
Args:
|
||||
@@ -94,14 +143,24 @@ async def get_user_v1_enabled_setting(user_id: str) -> bool:
|
||||
Returns:
|
||||
True if V1 conversations are enabled for this user, False otherwise
|
||||
"""
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
|
||||
if not org or org.v1_enabled is None:
|
||||
# If no user ID is provided, we can't check user settings
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
return org.v1_enabled
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
|
||||
if not settings or settings.v1_enabled is None:
|
||||
return False
|
||||
|
||||
return settings.v1_enabled
|
||||
|
||||
|
||||
# =================================================
|
||||
@@ -124,7 +183,6 @@ class GithubIssue(ResolverViewInterface):
|
||||
title: str
|
||||
description: str
|
||||
previous_comments: list[Comment]
|
||||
v1: bool
|
||||
|
||||
async def _load_resolver_context(self):
|
||||
github_service = GithubServiceImpl(
|
||||
@@ -159,7 +217,6 @@ class GithubIssue(ResolverViewInterface):
|
||||
issue_body=self.description,
|
||||
previous_comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
@@ -172,19 +229,6 @@ class GithubIssue(ResolverViewInterface):
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
|
||||
v1_enabled = await get_user_v1_enabled_setting(self.user_info.keycloak_user_id)
|
||||
logger.info(
|
||||
f'[GitHub V1]: User flag found for {self.user_info.keycloak_user_id} is {v1_enabled}'
|
||||
)
|
||||
if v1_enabled:
|
||||
# Create dummy conversationm metadata
|
||||
# Don't save to conversation store
|
||||
# V1 conversations are stored in a separate table
|
||||
return ConversationMetadata(
|
||||
conversation_id=uuid4().hex, selected_repository=self.full_repo_name
|
||||
)
|
||||
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
@@ -193,7 +237,6 @@ class GithubIssue(ResolverViewInterface):
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
self.conversation_id = conversation_metadata.conversation_id
|
||||
return conversation_metadata
|
||||
|
||||
@@ -202,17 +245,14 @@ class GithubIssue(ResolverViewInterface):
|
||||
jinja_env: Environment,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
saas_user_auth: UserAuth,
|
||||
):
|
||||
v1_enabled = await get_user_v1_enabled_setting(self.user_info.keycloak_user_id)
|
||||
logger.info(
|
||||
f'[GitHub V1]: User flag found for {self.user_info.keycloak_user_id} is {v1_enabled}'
|
||||
)
|
||||
|
||||
if v1_enabled:
|
||||
try:
|
||||
# Use V1 app conversation service
|
||||
await self._create_v1_conversation(
|
||||
jinja_env, saas_user_auth, conversation_metadata
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
return
|
||||
|
||||
@@ -231,7 +271,6 @@ class GithubIssue(ResolverViewInterface):
|
||||
conversation_metadata: ConversationMetadata,
|
||||
):
|
||||
"""Create conversation using the legacy V0 system."""
|
||||
logger.info('[GitHub V1]: Creating V0 conversation')
|
||||
custom_secrets = await self._get_user_secrets()
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
@@ -253,12 +292,10 @@ class GithubIssue(ResolverViewInterface):
|
||||
async def _create_v1_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
saas_user_auth: UserAuth,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
):
|
||||
"""Create conversation using the new V1 app conversation system."""
|
||||
logger.info('[GitHub V1]: Creating V1 conversation')
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja_env
|
||||
)
|
||||
@@ -289,7 +326,10 @@ class GithubIssue(ResolverViewInterface):
|
||||
)
|
||||
|
||||
# Set up the GitHub user context for the V1 system
|
||||
github_user_context = ResolverUserContext(saas_user_auth=saas_user_auth)
|
||||
github_user_context = GithubUserContext(
|
||||
keycloak_user_id=self.user_info.keycloak_user_id,
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
)
|
||||
setattr(injector_state, USER_CONTEXT_ATTR, github_user_context)
|
||||
|
||||
async with get_app_conversation_service(
|
||||
@@ -304,8 +344,6 @@ class GithubIssue(ResolverViewInterface):
|
||||
f'Failed to start V1 conversation: {task.detail}'
|
||||
)
|
||||
|
||||
self.v1 = True
|
||||
|
||||
def _create_github_v1_callback_processor(self):
|
||||
"""Create a V1 callback processor for GitHub integration."""
|
||||
from openhands.app_server.event_callback.github_v1_callback_processor import (
|
||||
@@ -340,6 +378,7 @@ class GithubIssueComment(GithubIssue):
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
issue_title=self.title,
|
||||
@@ -376,7 +415,8 @@ class GithubPRComment(GithubIssueComment):
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation(
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
selected_repository=self.full_repo_name,
|
||||
@@ -422,6 +462,7 @@ class GithubInlinePRComment(GithubPRComment):
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'pr_update_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
pr_number=self.issue_number,
|
||||
pr_title=self.title,
|
||||
@@ -765,7 +806,7 @@ class GithubFactory:
|
||||
|
||||
@staticmethod
|
||||
async def create_github_view_from_payload(
|
||||
message: Message, keycloak_user_id: str
|
||||
message: Message, token_manager: TokenManager
|
||||
) -> ResolverViewInterface:
|
||||
"""Create the appropriate class (GithubIssue or GithubPRComment) based on the payload.
|
||||
Also return metadata about the event (e.g., action type).
|
||||
@@ -775,10 +816,17 @@ class GithubFactory:
|
||||
user_id = payload['sender']['id']
|
||||
username = payload['sender']['login']
|
||||
|
||||
keyloak_user_id = await token_manager.get_user_id_from_idp_user_id(
|
||||
user_id, ProviderType.GITHUB
|
||||
)
|
||||
|
||||
if keyloak_user_id is None:
|
||||
logger.warning(f'Got invalid keyloak user id for GitHub User {user_id} ')
|
||||
|
||||
selected_repo = GithubFactory.get_full_repo_name(repo_obj)
|
||||
is_public_repo = not repo_obj.get('private', True)
|
||||
user_info = UserData(
|
||||
user_id=user_id, username=username, keycloak_user_id=keycloak_user_id
|
||||
user_id=user_id, username=username, keycloak_user_id=keyloak_user_id
|
||||
)
|
||||
|
||||
installation_id = message.message['installation']
|
||||
@@ -802,7 +850,6 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
elif GithubFactory.is_issue_comment(message):
|
||||
@@ -828,7 +875,6 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
elif GithubFactory.is_pr_comment(message):
|
||||
@@ -870,7 +916,6 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
elif GithubFactory.is_inline_pr_comment(message):
|
||||
@@ -904,7 +949,6 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.app_server.user.user_models import UserInfo
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.sdk.conversation.secret_source import SecretSource, StaticSecret
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
class ResolverUserContext(UserContext):
|
||||
"""User context for resolver operations that inherits from UserContext."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
saas_user_auth: UserAuth,
|
||||
):
|
||||
self.saas_user_auth = saas_user_auth
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return await self.saas_user_auth.get_user_id()
|
||||
|
||||
async def get_user_info(self) -> UserInfo:
|
||||
user_settings = await self.saas_user_auth.get_user_settings()
|
||||
user_id = await self.saas_user_auth.get_user_id()
|
||||
if user_settings:
|
||||
return UserInfo(
|
||||
id=user_id,
|
||||
**user_settings.model_dump(context={'expose_secrets': True}),
|
||||
)
|
||||
|
||||
return UserInfo(id=user_id)
|
||||
|
||||
async def get_authenticated_git_url(self, repository: str) -> str:
|
||||
# This would need to be implemented based on the git provider tokens
|
||||
# For now, return a basic HTTPS URL
|
||||
return f'https://github.com/{repository}.git'
|
||||
|
||||
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
|
||||
# Return the appropriate token from git_provider_tokens
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if provider_tokens:
|
||||
return provider_tokens.get(provider_type)
|
||||
return None
|
||||
|
||||
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
|
||||
return await self.saas_user_auth.get_provider_tokens()
|
||||
|
||||
async def get_secrets(self) -> dict[str, SecretSource]:
|
||||
"""Get secrets for the user, including custom secrets."""
|
||||
secrets = await self.saas_user_auth.get_secrets()
|
||||
if secrets:
|
||||
# Convert custom secrets to StaticSecret objects for SDK compatibility
|
||||
# secrets.custom_secrets is of type Mapping[str, CustomSecret]
|
||||
converted_secrets = {}
|
||||
for key, custom_secret in secrets.custom_secrets.items():
|
||||
# Extract the secret value from CustomSecret and convert to StaticSecret
|
||||
secret_value = custom_secret.secret.get_secret_value()
|
||||
converted_secrets[key] = StaticSecret(value=secret_value)
|
||||
return converted_secrets
|
||||
return {}
|
||||
|
||||
async def get_mcp_api_key(self) -> str | None:
|
||||
return await self.saas_user_auth.get_mcp_api_key()
|
||||
@@ -167,7 +167,6 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
'channel_id': self.channel_id,
|
||||
'conversation_id': self.conversation_id,
|
||||
'keycloak_user_id': user_info.keycloak_user_id,
|
||||
'org_id': user_info.org_id,
|
||||
'parent_id': self.thread_ts or self.message_ts,
|
||||
},
|
||||
)
|
||||
@@ -175,7 +174,6 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
conversation_id=self.conversation_id,
|
||||
channel_id=self.channel_id,
|
||||
keycloak_user_id=user_info.keycloak_user_id,
|
||||
org_id=user_info.org_id,
|
||||
parent_id=self.thread_ts
|
||||
or self.message_ts, # conversations can start in a thread reply as well; we should always references the parent's (root level msg's) message ID
|
||||
)
|
||||
@@ -306,10 +304,10 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
instructions, _ = self._get_instructions(jinja)
|
||||
user_msg = MessageAction(content=instructions)
|
||||
user_msg, _ = self._get_instructions(jinja)
|
||||
user_msg_action = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_msg)
|
||||
self.conversation_id, event_to_dict(user_msg_action)
|
||||
)
|
||||
|
||||
return self.conversation_id
|
||||
|
||||
@@ -1,24 +1,19 @@
|
||||
from uuid import UUID
|
||||
|
||||
import stripe
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.database import session_maker
|
||||
from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
|
||||
|
||||
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
with session_maker() as session:
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.org_id == org_id)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer:
|
||||
@@ -26,76 +21,46 @@ async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
|
||||
# If that fails, fallback to stripe
|
||||
search_result = await stripe.Customer.search_async(
|
||||
query=f"metadata['org_id']:'{str(org_id)}'",
|
||||
query=f"metadata['user_id']:'{user_id}'",
|
||||
)
|
||||
data = search_result.data
|
||||
if not data:
|
||||
logger.info(
|
||||
'no_customer_for_org_id',
|
||||
extra={'org_id': str(org_id)},
|
||||
)
|
||||
logger.info('no_customer_for_user_id', extra={'user_id': user_id})
|
||||
return None
|
||||
return data[0].id # type: ignore [attr-defined]
|
||||
|
||||
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
customer_id = await find_customer_id_by_org_id(org.id)
|
||||
return customer_id
|
||||
|
||||
|
||||
async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
# Get the current org for the user
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
|
||||
customer_id = await find_customer_id_by_org_id(org.id)
|
||||
async def find_or_create_customer(user_id: str) -> str:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id:
|
||||
return {'customer_id': customer_id, 'org_id': str(org.id)}
|
||||
logger.info(
|
||||
'creating_customer',
|
||||
extra={'user_id': user_id, 'org_id': str(org.id)},
|
||||
)
|
||||
return customer_id
|
||||
logger.info('creating_customer', extra={'user_id': user_id})
|
||||
|
||||
# Get the user info from keycloak
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id) or {}
|
||||
|
||||
# Create the customer in stripe
|
||||
customer = await stripe.Customer.create_async(
|
||||
email=org.contact_email,
|
||||
metadata={'org_id': str(org.id)},
|
||||
email=str(user_info.get('email', '')),
|
||||
metadata={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Save the stripe customer in the local db
|
||||
with session_maker() as session:
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id=user_id,
|
||||
org_id=org.id,
|
||||
stripe_customer_id=customer.id,
|
||||
)
|
||||
StripeCustomer(keycloak_user_id=user_id, stripe_customer_id=customer.id)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'created_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
extra={'user_id': user_id, 'stripe_customer_id': customer.id},
|
||||
)
|
||||
return {'customer_id': customer.id, 'org_id': str(org.id)}
|
||||
return customer.id
|
||||
|
||||
|
||||
async def has_payment_method_by_user_id(user_id: str) -> bool:
|
||||
async def has_payment_method(user_id: str) -> bool:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id is None:
|
||||
return False
|
||||
@@ -106,28 +71,3 @@ async def has_payment_method_by_user_id(user_id: str) -> bool:
|
||||
f'has_payment_method:{user_id}:{customer_id}:{bool(payment_methods.data)}'
|
||||
)
|
||||
return bool(payment_methods.data)
|
||||
|
||||
|
||||
async def migrate_customer(session: Session, user_id: str, org: Org):
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ class PRStatus(Enum):
|
||||
class UserData(BaseModel):
|
||||
user_id: int
|
||||
username: str
|
||||
keycloak_user_id: str
|
||||
keycloak_user_id: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
from pydantic import SecretStr
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
async def get_saas_user_auth(
|
||||
keycloak_user_id: str, token_manager: TokenManager
|
||||
) -> UserAuth:
|
||||
offline_token = await token_manager.load_offline_token(keycloak_user_id)
|
||||
if offline_token is None:
|
||||
logger.info('no_offline_token_found')
|
||||
|
||||
user_auth = SaasUserAuth(
|
||||
user_id=keycloak_user_id,
|
||||
refresh_token=SecretStr(offline_token),
|
||||
)
|
||||
return user_auth
|
||||
@@ -20,8 +20,6 @@ down_revision = '059'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# TODO: decide whether to modify this for orgs or users
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""
|
||||
@@ -30,10 +28,8 @@ def upgrade():
|
||||
|
||||
This replaces the functionality of the removed admin maintenance endpoint.
|
||||
"""
|
||||
|
||||
# Hardcoded value to prevent migration failures when constant is removed from codebase
|
||||
# This migration has already run in production, so we use the value that was current at the time
|
||||
CURRENT_USER_SETTINGS_VERSION = 4
|
||||
# Import here to avoid circular imports
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
|
||||
# Create a connection and bind it to a session
|
||||
connection = op.get_bind()
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
"""create org tables from pgerd schema
|
||||
|
||||
Revision ID: 084
|
||||
Revises: 083
|
||||
Create Date: 2025-01-07 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '084'
|
||||
down_revision: Union[str, None] = '083'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Remove current settings table
|
||||
op.execute('DROP TABLE IF EXISTS settings')
|
||||
|
||||
# Add already_migrated column to user_settings table
|
||||
op.add_column(
|
||||
'user_settings',
|
||||
sa.Column(
|
||||
'already_migrated',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
# Create role table
|
||||
op.create_table(
|
||||
'role',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('name', sa.String, nullable=False),
|
||||
sa.Column('rank', sa.Integer, nullable=False),
|
||||
sa.UniqueConstraint('name', name='role_name_unique'),
|
||||
)
|
||||
|
||||
# 1. Create default roles
|
||||
op.execute(
|
||||
sa.text("""
|
||||
INSERT INTO role (name, rank) VALUES ('owner', 10), ('admin', 20), ('user', 1000)
|
||||
ON CONFLICT (name) DO NOTHING;
|
||||
""")
|
||||
)
|
||||
|
||||
# Create org table with settings fields
|
||||
op.create_table(
|
||||
'org',
|
||||
sa.Column(
|
||||
'id',
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('name', sa.String, nullable=False),
|
||||
sa.Column('contact_name', sa.String, nullable=True),
|
||||
sa.Column('contact_email', sa.String, nullable=True),
|
||||
sa.Column('conversation_expiration', sa.Integer, nullable=True),
|
||||
# Settings fields moved to org table
|
||||
sa.Column('agent', sa.String, nullable=True),
|
||||
sa.Column('default_max_iterations', sa.Integer, nullable=True),
|
||||
sa.Column('security_analyzer', sa.String, nullable=True),
|
||||
sa.Column(
|
||||
'confirmation_mode',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('default_llm_model', sa.String, nullable=True),
|
||||
sa.Column('_default_llm_api_key_for_byor', sa.String, nullable=True),
|
||||
sa.Column('default_llm_base_url', sa.String, nullable=True),
|
||||
sa.Column('remote_runtime_resource_factor', sa.Integer, nullable=True),
|
||||
sa.Column(
|
||||
'enable_default_condenser',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column('billing_margin', sa.Float, nullable=True),
|
||||
sa.Column(
|
||||
'enable_proactive_conversation_starters',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column('sandbox_base_container_image', sa.String, nullable=True),
|
||||
sa.Column('sandbox_runtime_container_image', sa.String, nullable=True),
|
||||
sa.Column(
|
||||
'org_version', sa.Integer, nullable=False, server_default=sa.text('0')
|
||||
),
|
||||
sa.Column('mcp_config', sa.JSON, nullable=True),
|
||||
sa.Column('_search_api_key', sa.String, nullable=True),
|
||||
sa.Column('_sandbox_api_key', sa.String, nullable=True),
|
||||
sa.Column('max_budget_per_task', sa.Float, nullable=True),
|
||||
sa.Column(
|
||||
'enable_solvability_analysis',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('v1_enabled', sa.Boolean, nullable=True),
|
||||
sa.UniqueConstraint('name', name='org_name_unique'),
|
||||
)
|
||||
|
||||
# Create user table with user-specific settings fields
|
||||
op.create_table(
|
||||
'user',
|
||||
sa.Column(
|
||||
'id',
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('current_org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=True),
|
||||
sa.Column('accepted_tos', sa.DateTime, nullable=True),
|
||||
sa.Column(
|
||||
'enable_sound_notifications',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('language', sa.String, nullable=True),
|
||||
sa.Column('user_consents_to_analytics', sa.Boolean, nullable=True),
|
||||
sa.Column('email', sa.String, nullable=True),
|
||||
sa.Column('email_verified', sa.Boolean, nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
['current_org_id'], ['org.id'], name='current_org_fkey'
|
||||
),
|
||||
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='user_role_fkey'),
|
||||
)
|
||||
|
||||
# Create org_member table (junction table for many-to-many relationship)
|
||||
op.create_table(
|
||||
'org_member',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=False),
|
||||
sa.Column('_llm_api_key', sa.String, nullable=False),
|
||||
sa.Column('max_iterations', sa.Integer, nullable=True),
|
||||
sa.Column('llm_model', sa.String, nullable=True),
|
||||
sa.Column('_llm_api_key_for_byor', sa.String, nullable=True),
|
||||
sa.Column('llm_base_url', sa.String, nullable=True),
|
||||
sa.Column('status', sa.String, nullable=True),
|
||||
sa.ForeignKeyConstraint(['org_id'], ['org.id'], name='om_org_fkey'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='om_user_fkey'),
|
||||
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='om_role_fkey'),
|
||||
sa.PrimaryKeyConstraint('org_id', 'user_id'),
|
||||
)
|
||||
|
||||
# Add org_id column to existing tables
|
||||
# billing_sessions
|
||||
op.add_column(
|
||||
'billing_sessions',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'billing_sessions_org_fkey', 'billing_sessions', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# Create conversation_metadata_saas table
|
||||
op.create_table(
|
||||
'conversation_metadata_saas',
|
||||
sa.Column('conversation_id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
['user_id'], ['user.id'], name='conversation_metadata_saas_user_fkey'
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['org_id'], ['org.id'], name='conversation_metadata_saas_org_fkey'
|
||||
),
|
||||
sa.PrimaryKeyConstraint('conversation_id'),
|
||||
)
|
||||
|
||||
# custom_secrets
|
||||
op.add_column(
|
||||
'custom_secrets',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'custom_secrets_org_fkey', 'custom_secrets', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# api_keys
|
||||
op.add_column(
|
||||
'api_keys', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
|
||||
)
|
||||
op.create_foreign_key('api_keys_org_fkey', 'api_keys', 'org', ['org_id'], ['id'])
|
||||
|
||||
# slack_conversation
|
||||
op.add_column(
|
||||
'slack_conversation',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'slack_conversation_org_fkey', 'slack_conversation', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# slack_users
|
||||
op.add_column(
|
||||
'slack_users', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'slack_users_org_fkey', 'slack_users', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# stripe_customers
|
||||
op.alter_column(
|
||||
'stripe_customers',
|
||||
'keycloak_user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
op.add_column(
|
||||
'stripe_customers',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'stripe_customers_org_fkey', 'stripe_customers', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop already_migrated column from user_settings table
|
||||
op.drop_column('user_settings', 'already_migrated')
|
||||
|
||||
# Drop foreign keys and columns added to existing tables
|
||||
op.drop_constraint(
|
||||
'stripe_customers_org_fkey', 'stripe_customers', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('stripe_customers', 'org_id')
|
||||
op.alter_column(
|
||||
'stripe_customers',
|
||||
'keycloak_user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
op.drop_constraint('slack_users_org_fkey', 'slack_users', type_='foreignkey')
|
||||
op.drop_column('slack_users', 'org_id')
|
||||
|
||||
op.drop_constraint(
|
||||
'slack_conversation_org_fkey', 'slack_conversation', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('slack_conversation', 'org_id')
|
||||
|
||||
op.drop_constraint('api_keys_org_fkey', 'api_keys', type_='foreignkey')
|
||||
op.drop_column('api_keys', 'org_id')
|
||||
|
||||
op.drop_constraint('custom_secrets_org_fkey', 'custom_secrets', type_='foreignkey')
|
||||
op.drop_column('custom_secrets', 'org_id')
|
||||
|
||||
# Drop conversation_metadata_saas table
|
||||
op.drop_table('conversation_metadata_saas')
|
||||
|
||||
op.drop_constraint(
|
||||
'billing_sessions_org_fkey', 'billing_sessions', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('billing_sessions', 'org_id')
|
||||
|
||||
# Drop tables in reverse order due to foreign key constraints
|
||||
op.drop_table('org_member')
|
||||
op.drop_table('user')
|
||||
op.drop_table('org')
|
||||
op.drop_table('role')
|
||||
Generated
+4829
-5176
File diff suppressed because one or more lines are too long
@@ -4,10 +4,6 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Ensure SAAS configuration is used
|
||||
if not os.getenv('OPENHANDS_CONFIG_CLS'):
|
||||
os.environ['OPENHANDS_CONFIG_CLS'] = 'server.config.SaaSServerConfig'
|
||||
|
||||
import socketio # noqa: E402
|
||||
from fastapi import Request, status # noqa: E402
|
||||
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
|
||||
|
||||
@@ -102,6 +102,7 @@ class SaasUserAuth(UserAuth):
|
||||
return settings
|
||||
settings_store = await self.get_user_settings_store()
|
||||
settings = await settings_store.load()
|
||||
# If load() returned None, should settings be created?
|
||||
if settings:
|
||||
settings.email = self.email
|
||||
settings.email_verified = self.email_verified
|
||||
|
||||
@@ -9,7 +9,7 @@ from server.logger import logger
|
||||
from server.utils.conversation_callback_utils import invoke_conversation_callbacks
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
@@ -525,18 +525,16 @@ class ClusteredConversationManager(StandaloneConversationManager):
|
||||
)
|
||||
# Look up the user_id from the database
|
||||
with session_maker() as session:
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
StoredConversationMetadata.conversation_id
|
||||
== conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
user_id = (
|
||||
str(conversation_metadata_saas.user_id)
|
||||
if conversation_metadata_saas
|
||||
else None
|
||||
conversation_metadata.user_id if conversation_metadata else None
|
||||
)
|
||||
# Handle the stopped conversation asynchronously
|
||||
asyncio.create_task(
|
||||
|
||||
@@ -19,8 +19,8 @@ IS_LOCAL_ENV = bool(HOST == 'localhost')
|
||||
DEFAULT_BILLING_MARGIN = float(os.environ.get('DEFAULT_BILLING_MARGIN', '1.0'))
|
||||
|
||||
# Map of user settings versions to their corresponding default LLM models
|
||||
# This ensures that PERSONAL_WORKSPACE_VERSION_TO_MODEL and LITELLM_DEFAULT_MODEL stay in sync
|
||||
PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
|
||||
# This ensures that CURRENT_USER_SETTINGS_VERSION and LITELLM_DEFAULT_MODEL stay in sync
|
||||
USER_SETTINGS_VERSION_TO_MODEL = {
|
||||
1: 'claude-3-5-sonnet-20241022',
|
||||
2: 'claude-3-7-sonnet-20250219',
|
||||
3: 'claude-sonnet-4-20250514',
|
||||
@@ -30,17 +30,29 @@ PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
|
||||
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
|
||||
|
||||
# Current user settings version - this should be the latest key in USER_SETTINGS_VERSION_TO_MODEL
|
||||
ORG_SETTINGS_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
|
||||
PERSONAL_WORKSPACE_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
|
||||
CURRENT_USER_SETTINGS_VERSION = max(USER_SETTINGS_VERSION_TO_MODEL.keys())
|
||||
|
||||
LITE_LLM_API_URL = os.environ.get(
|
||||
'LITE_LLM_API_URL', 'https://llm-proxy.app.all-hands.dev'
|
||||
)
|
||||
LITE_LLM_TEAM_ID = os.environ.get('LITE_LLM_TEAM_ID', None)
|
||||
LITE_LLM_API_KEY = os.environ.get('LITE_LLM_API_KEY', None)
|
||||
SUBSCRIPTION_PRICE_DATA = {
|
||||
'MONTHLY_SUBSCRIPTION': {
|
||||
'unit_amount': 2000,
|
||||
'currency': 'usd',
|
||||
'product_data': {
|
||||
'name': 'OpenHands Monthly',
|
||||
'tax_code': 'txcd_10000000',
|
||||
},
|
||||
'tax_behavior': 'exclusive',
|
||||
'recurring': {'interval': 'month', 'interval_count': 1},
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
|
||||
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
|
||||
STRIPE_WEBHOOK_SECRET = os.environ.get('STRIPE_WEBHOOK_SECRET', None)
|
||||
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')
|
||||
|
||||
SLACK_CLIENT_ID = os.environ.get('SLACK_CLIENT_ID', None)
|
||||
@@ -90,5 +102,5 @@ def get_default_litellm_model():
|
||||
"""
|
||||
if LITELLM_DEFAULT_MODEL:
|
||||
return LITELLM_DEFAULT_MODEL
|
||||
model = PERSONAL_WORKSPACE_VERSION_TO_MODEL[PERSONAL_WORKSPACE_VERSION]
|
||||
model = USER_SETTINGS_VERSION_TO_MODEL[CURRENT_USER_SETTINGS_VERSION]
|
||||
return build_litellm_proxy_model_path(model)
|
||||
|
||||
@@ -44,13 +44,11 @@ class MyProcessor(MaintenanceTaskProcessor):
|
||||
### UserVersionUpgradeProcessor
|
||||
|
||||
Located in `user_version_upgrade_processor.py`, this processor:
|
||||
|
||||
- Handles up to 100 user IDs per task
|
||||
- Upgrades users with `user_version < ORG_SETTINGS_VERSION`
|
||||
- Upgrades users with `user_version < CURRENT_USER_SETTINGS_VERSION`
|
||||
- Uses `SaasSettingsStore.create_default_settings()` for upgrades
|
||||
|
||||
**Usage:**
|
||||
|
||||
```python
|
||||
from server.maintenance_task_processor.user_version_upgrade_processor import UserVersionUpgradeProcessor
|
||||
|
||||
@@ -146,26 +144,22 @@ task = create_maintenance_task(
|
||||
## Best Practices
|
||||
|
||||
### Processor Design
|
||||
|
||||
- Keep tasks short-running (under 1 minute)
|
||||
- Handle errors gracefully and return meaningful error information
|
||||
- Use batch processing for large datasets
|
||||
- Include progress information in the return dict
|
||||
|
||||
### Error Handling
|
||||
|
||||
- Always wrap your processor logic in try-catch blocks
|
||||
- Return structured error information
|
||||
- Log important events for debugging
|
||||
|
||||
### Performance
|
||||
|
||||
- Limit batch sizes to avoid long-running tasks
|
||||
- Use database sessions efficiently
|
||||
- Consider memory usage for large datasets
|
||||
|
||||
### Testing
|
||||
|
||||
- Create unit tests for your processors
|
||||
- Test error conditions
|
||||
- Verify the processor serialization/deserialization works correctly
|
||||
@@ -173,7 +167,6 @@ task = create_maintenance_task(
|
||||
## Database Patterns
|
||||
|
||||
The maintenance task system follows the repository's established patterns:
|
||||
|
||||
- Uses `session_maker()` for database operations
|
||||
- Wraps sync database operations in `call_sync_from_async` for async routes
|
||||
- Follows proper SQLAlchemy query patterns
|
||||
@@ -181,18 +174,15 @@ The maintenance task system follows the repository's established patterns:
|
||||
## Integration with Existing Systems
|
||||
|
||||
### User Management
|
||||
|
||||
- Integrates with the existing `UserSettings` model
|
||||
- Uses the current user versioning system (`ORG_SETTINGS_VERSION`)
|
||||
- Uses the current user versioning system (`CURRENT_USER_SETTINGS_VERSION`)
|
||||
- Maintains compatibility with existing user management workflows
|
||||
|
||||
### Authentication
|
||||
|
||||
- Admin endpoints use the existing SaaS authentication system
|
||||
- Requires users to have `admin = True` in their UserSettings
|
||||
|
||||
### Monitoring
|
||||
|
||||
- Tasks are logged with structured information
|
||||
- Status updates are tracked in the database
|
||||
- Error information is preserved for debugging
|
||||
@@ -216,7 +206,6 @@ The maintenance task system follows the repository's established patterns:
|
||||
## Future Enhancements
|
||||
|
||||
Potential improvements that could be added:
|
||||
|
||||
- Task dependencies and scheduling
|
||||
- Retry mechanisms for failed tasks
|
||||
- Real-time progress updates
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
from server.logger import logger
|
||||
from storage.database import session_maker
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskProcessor
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
|
||||
|
||||
class UserVersionUpgradeProcessor(MaintenanceTaskProcessor):
|
||||
"""
|
||||
Processor for upgrading user settings to the current version.
|
||||
|
||||
This processor takes a list of user IDs and upgrades any users
|
||||
whose user_version is less than CURRENT_USER_SETTINGS_VERSION.
|
||||
"""
|
||||
|
||||
user_ids: List[str]
|
||||
|
||||
async def __call__(self, task: MaintenanceTask) -> dict:
|
||||
"""
|
||||
Process user version upgrades for the specified user IDs.
|
||||
|
||||
Args:
|
||||
task: The maintenance task being processed
|
||||
|
||||
Returns:
|
||||
dict: Results containing successful and failed user IDs
|
||||
"""
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:start',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_count': len(self.user_ids),
|
||||
'current_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
if len(self.user_ids) > 100:
|
||||
raise ValueError(
|
||||
f'Too many user IDs: {len(self.user_ids)}. Maximum is 100.'
|
||||
)
|
||||
|
||||
config = load_openhands_config()
|
||||
|
||||
# Track results
|
||||
successful_upgrades = []
|
||||
failed_upgrades = []
|
||||
users_already_current = []
|
||||
|
||||
# Find users that need upgrading
|
||||
with session_maker() as session:
|
||||
users_to_upgrade = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id.in_(self.user_ids),
|
||||
UserSettings.user_version < CURRENT_USER_SETTINGS_VERSION,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Track users that are already current
|
||||
users_needing_upgrade_ids = {u.keycloak_user_id for u in users_to_upgrade}
|
||||
users_already_current = [
|
||||
uid for uid in self.user_ids if uid not in users_needing_upgrade_ids
|
||||
]
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:found_users',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'users_to_upgrade': len(users_to_upgrade),
|
||||
'users_already_current': len(users_already_current),
|
||||
'total_requested': len(self.user_ids),
|
||||
},
|
||||
)
|
||||
|
||||
# Process each user that needs upgrading
|
||||
for user_settings in users_to_upgrade:
|
||||
user_id = user_settings.keycloak_user_id
|
||||
old_version = user_settings.user_version
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:upgrading_user',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
# Create SaasSettingsStore instance and upgrade
|
||||
settings_store = await SaasSettingsStore.get_instance(config, user_id)
|
||||
await settings_store.create_default_settings(user_settings)
|
||||
|
||||
successful_upgrades.append(
|
||||
{
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:user_upgraded',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_upgrades.append(
|
||||
{'user_id': user_id, 'old_version': old_version, 'error': str(e)}
|
||||
)
|
||||
|
||||
logger.error(
|
||||
'user_version_upgrade_processor:user_upgrade_failed',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
|
||||
# Create result summary
|
||||
result = {
|
||||
'total_users': len(self.user_ids),
|
||||
'users_already_current': users_already_current,
|
||||
'successful_upgrades': successful_upgrades,
|
||||
'failed_upgrades': failed_upgrades,
|
||||
'summary': (
|
||||
f'Processed {len(self.user_ids)} users: '
|
||||
f'{len(successful_upgrades)} upgraded, '
|
||||
f'{len(users_already_current)} already current, '
|
||||
f'{len(failed_upgrades)} errors'
|
||||
),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:completed',
|
||||
extra={'task_id': task.id, 'result': result},
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -1,5 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
|
||||
@@ -34,7 +36,6 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
|
||||
Returns:
|
||||
A tuple containing the default SSE server configuration and a list of MCP stdio server configurations
|
||||
"""
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
if user_id:
|
||||
|
||||
@@ -1,97 +1,109 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from server.config import get_config
|
||||
from server.constants import LITE_LLM_API_KEY, LITE_LLM_API_URL
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.org_store import OrgStore
|
||||
from storage.user_store import UserStore
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
|
||||
# Helper functions for BYOR API key management
|
||||
async def get_byor_key_from_db(user_id: str) -> str | None:
|
||||
"""Get the BYOR key from the database for a user."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
def _get_byor_key():
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
if current_org_member.llm_api_key_for_byor:
|
||||
return current_org_member.llm_api_key_for_byor.get_secret_value()
|
||||
|
||||
org = OrgStore.get_org_by_id(current_org_id)
|
||||
if not org:
|
||||
return None
|
||||
return (
|
||||
org.default_llm_api_key_for_byor.get_secret_value()
|
||||
if org.default_llm_api_key_for_byor
|
||||
else None
|
||||
)
|
||||
|
||||
return await call_sync_from_async(_get_byor_key)
|
||||
user_db_settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
if user_db_settings and user_db_settings.llm_api_key_for_byor:
|
||||
return user_db_settings.llm_api_key_for_byor
|
||||
return None
|
||||
|
||||
|
||||
async def store_byor_key_in_db(user_id: str, key: str) -> None:
|
||||
"""Store the BYOR key in the database for a user."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
def _update_user_settings():
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
current_org_member.llm_api_key_for_byor = key
|
||||
OrgMemberStore.update_org_member(current_org_member)
|
||||
with session_maker() as session:
|
||||
user_db_settings = settings_store.get_user_settings_by_keycloak_id(
|
||||
user_id, session
|
||||
)
|
||||
if user_db_settings:
|
||||
user_db_settings.llm_api_key_for_byor = key
|
||||
session.commit()
|
||||
logger.info(
|
||||
'Successfully stored BYOR key in user settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'User settings not found when trying to store BYOR key',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
await call_sync_from_async(_update_user_settings)
|
||||
|
||||
|
||||
async def generate_byor_key(user_id: str) -> str | None:
|
||||
"""Generate a new BYOR key for a user."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
key = await LiteLlmManager.generate_key(
|
||||
user_id, user_id, f'BYOR Key - user {user_id}', {'type': 'byor'}
|
||||
)
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'Successfully generated new BYOR key',
|
||||
extra={
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json={
|
||||
'user_id': user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': key[:10] + '...' if key and len(key) > 10 else key,
|
||||
'metadata': {'type': 'byor'},
|
||||
'key_alias': f'BYOR Key - user {user_id}',
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate BYOR LLM API key - no key in response',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json.get('key')
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'Successfully generated new BYOR key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': key[:10] + '...'
|
||||
if key and len(key) > 10
|
||||
else key,
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate BYOR LLM API key - no key in response',
|
||||
extra={'user_id': user_id, 'response_json': response_json},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error generating BYOR key',
|
||||
@@ -102,14 +114,30 @@ async def generate_byor_key(user_id: str) -> str | None:
|
||||
|
||||
async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
|
||||
"""Delete the BYOR key from LiteLLM using the key directly."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
await LiteLlmManager.delete_key(byor_key)
|
||||
logger.info(
|
||||
'Successfully deleted BYOR key from LiteLLM',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
# Delete the key directly using the key value
|
||||
delete_url = f'{LITE_LLM_API_URL}/key/delete'
|
||||
delete_payload = {'keys': [byor_key]}
|
||||
|
||||
delete_response = await client.post(delete_url, json=delete_payload)
|
||||
delete_response.raise_for_status()
|
||||
logger.info(
|
||||
'Successfully deleted BYOR key from LiteLLM',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error deleting BYOR key from LiteLLM',
|
||||
@@ -287,6 +315,15 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
logger.info('Starting BYOR LLM API key refresh', extra={'user_id': user_id})
|
||||
|
||||
try:
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='LiteLLM API configuration not found',
|
||||
)
|
||||
|
||||
# Get the existing BYOR key from the database
|
||||
existing_byor_key = await get_byor_key_from_db(user_id)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import uuid
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Literal, Optional
|
||||
@@ -18,12 +17,12 @@ from server.auth.constants import (
|
||||
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import sign_token
|
||||
from server.config import get_config, sign_token
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from server.routes.event_webhook import _get_session_api_key, _get_user_id
|
||||
from storage.database import session_maker
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
@@ -32,7 +31,6 @@ from openhands.server.services.conversation_service import create_provider_token
|
||||
from openhands.server.shared import config
|
||||
from openhands.server.user_auth import get_access_token
|
||||
from openhands.server.user_auth.user_auth import get_user_auth
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
@@ -84,8 +82,7 @@ def get_cookie_domain(request: Request) -> str | None:
|
||||
# for now just use the full hostname except for staging stacks.
|
||||
return (
|
||||
None
|
||||
if not request.url.hostname
|
||||
or request.url.hostname.endswith('staging.all-hands.dev')
|
||||
if (request.url.hostname or '').endswith('staging.all-hand.dev')
|
||||
else request.url.hostname
|
||||
)
|
||||
|
||||
@@ -149,21 +146,6 @@ async def keycloak_callback(
|
||||
)
|
||||
|
||||
user_id = user_info['sub']
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, user_id)
|
||||
if not user:
|
||||
user = await UserStore.create_user(user_id, user_info)
|
||||
|
||||
if not user:
|
||||
logger.error(f'Failed to authenticate user {user_info["preferred_username"]}')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': f'Failed to authenticate user {user_info["preferred_username"]}'
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f'Logging in user {str(user.id)} in org {user.current_org_id}')
|
||||
|
||||
# default to github IDP for now.
|
||||
# TODO: remove default once Keycloak is updated universally with the new attribute.
|
||||
idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value)
|
||||
@@ -238,7 +220,15 @@ async def keycloak_callback(
|
||||
f'&state={state}'
|
||||
)
|
||||
|
||||
has_accepted_tos = user.accepted_tos is not None
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
user_settings = settings_store.get_user_settings_by_keycloak_id(user_id)
|
||||
has_accepted_tos = (
|
||||
user_settings is not None and user_settings.accepted_tos is not None
|
||||
)
|
||||
|
||||
# If the user hasn't accepted the TOS, redirect to the TOS page
|
||||
if not has_accepted_tos:
|
||||
encoded_redirect_url = quote(redirect_url, safe='')
|
||||
@@ -356,20 +346,28 @@ async def accept_tos(request: Request):
|
||||
redirect_url = body.get('redirect_url', str(request.base_url))
|
||||
|
||||
# Update user settings with TOS acceptance
|
||||
accepted_tos: datetime = datetime.now(timezone.utc)
|
||||
with session_maker() as session:
|
||||
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
|
||||
if not user:
|
||||
session.rollback()
|
||||
logger.error('User for {user_id} not found.')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'User does not exist'},
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user_settings:
|
||||
user_settings.accepted_tos = datetime.now(timezone.utc)
|
||||
session.merge(user_settings)
|
||||
else:
|
||||
# Create user settings if they don't exist
|
||||
user_settings = UserSettings(
|
||||
keycloak_user_id=user_id,
|
||||
accepted_tos=datetime.now(timezone.utc),
|
||||
user_version=0, # This will trigger a migration to the latest version on next load
|
||||
)
|
||||
user.accepted_tos = accepted_tos
|
||||
session.add(user_settings)
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
|
||||
response = JSONResponse(
|
||||
status_code=status.HTTP_200_OK, content={'redirect_url': redirect_url}
|
||||
|
||||
@@ -2,23 +2,32 @@
|
||||
import typing
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
import stripe
|
||||
from dateutil.relativedelta import relativedelta # type: ignore
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.config import get_config
|
||||
from server.constants import (
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
STRIPE_API_KEY,
|
||||
STRIPE_WEBHOOK_SECRET,
|
||||
SUBSCRIPTION_PRICE_DATA,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.user_store import UserStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
billing_router = APIRouter(prefix='/api/billing')
|
||||
@@ -55,10 +64,23 @@ def validate_saas_environment(request: Request) -> None:
|
||||
)
|
||||
|
||||
|
||||
class BillingSessionType(Enum):
|
||||
DIRECT_PAYMENT = 'DIRECT_PAYMENT'
|
||||
MONTHLY_SUBSCRIPTION = 'MONTHLY_SUBSCRIPTION'
|
||||
|
||||
|
||||
class GetCreditsResponse(BaseModel):
|
||||
credits: Decimal | None = None
|
||||
|
||||
|
||||
class SubscriptionAccessResponse(BaseModel):
|
||||
start_at: datetime
|
||||
end_at: datetime
|
||||
created_at: datetime
|
||||
cancelled_at: datetime | None = None
|
||||
stripe_subscription_id: str | None = None
|
||||
|
||||
|
||||
class CreateCheckoutSessionRequest(BaseModel):
|
||||
amount: int
|
||||
|
||||
@@ -89,23 +111,117 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
|
||||
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
|
||||
if not stripe_service.STRIPE_API_KEY:
|
||||
return GetCreditsResponse()
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, user_id)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
user_id, str(user.current_org_id)
|
||||
)
|
||||
# Update to use calculate_credits
|
||||
spend = user_team_info.get('spend', 0)
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
|
||||
credits = max(max_budget - spend, 0)
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
user_json = await _get_litellm_user(client, user_id)
|
||||
credits = calculate_credits(user_json['user_info'])
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
|
||||
|
||||
# Endpoint to retrieve user's current subscription access
|
||||
@billing_router.get('/subscription-access')
|
||||
async def get_subscription_access(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> SubscriptionAccessResponse | None:
|
||||
"""Get details of the currently valid subscription for the user."""
|
||||
with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.first()
|
||||
)
|
||||
if not subscription_access:
|
||||
return None
|
||||
return SubscriptionAccessResponse(
|
||||
start_at=subscription_access.start_at,
|
||||
end_at=subscription_access.end_at,
|
||||
created_at=subscription_access.created_at,
|
||||
cancelled_at=subscription_access.cancelled_at,
|
||||
stripe_subscription_id=subscription_access.stripe_subscription_id,
|
||||
)
|
||||
|
||||
|
||||
# Endpoint to check if a user has entered a payment method into stripe
|
||||
@billing_router.post('/has-payment-method')
|
||||
async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
|
||||
if not user_id:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return await stripe_service.has_payment_method_by_user_id(user_id)
|
||||
return await stripe_service.has_payment_method(user_id)
|
||||
|
||||
|
||||
# Endpoint to cancel user's subscription
|
||||
@billing_router.post('/cancel-subscription')
|
||||
async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONResponse:
|
||||
"""Cancel user's active subscription at the end of the current billing period."""
|
||||
if not user_id:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
with session_maker() as session:
|
||||
# Find the user's active subscription
|
||||
now = datetime.now(UTC)
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not already cancelled
|
||||
.first()
|
||||
)
|
||||
|
||||
if not subscription_access:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='No active subscription found',
|
||||
)
|
||||
|
||||
if not subscription_access.stripe_subscription_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot cancel subscription: missing Stripe subscription ID',
|
||||
)
|
||||
|
||||
try:
|
||||
# Cancel the subscription in Stripe at period end
|
||||
await stripe.Subscription.modify_async(
|
||||
subscription_access.stripe_subscription_id, cancel_at_period_end=True
|
||||
)
|
||||
|
||||
# Update local database
|
||||
subscription_access.cancelled_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'subscription_cancelled',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'stripe_subscription_id': subscription_access.stripe_subscription_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
'end_at': subscription_access.end_at,
|
||||
},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
{'status': 'success', 'message': 'Subscription cancelled successfully'}
|
||||
)
|
||||
|
||||
except stripe.StripeError as e:
|
||||
logger.error(
|
||||
'stripe_cancellation_failed',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'stripe_subscription_id': subscription_access.stripe_subscription_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f'Failed to cancel subscription: {str(e)}',
|
||||
)
|
||||
|
||||
|
||||
# Endpoint to create a new setup intent in stripe
|
||||
@@ -114,15 +230,16 @@ async def create_customer_setup_session(
|
||||
request: Request, user_id: str = Depends(get_user_id)
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_info['customer_id'],
|
||||
customer=customer_id,
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{request.base_url}?free_credits=success',
|
||||
cancel_url=f'{request.base_url}',
|
||||
)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# Endpoint to create a new Stripe checkout session for credit purchase
|
||||
@@ -134,9 +251,9 @@ async def create_checkout_session(
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_info['customer_id'],
|
||||
customer=customer_id,
|
||||
line_items=[
|
||||
{
|
||||
'price_data': {
|
||||
@@ -149,7 +266,7 @@ async def create_checkout_session(
|
||||
'tax_behavior': 'exclusive',
|
||||
},
|
||||
'quantity': 1,
|
||||
},
|
||||
}
|
||||
],
|
||||
mode='payment',
|
||||
payment_method_types=['card'],
|
||||
@@ -162,9 +279,8 @@ async def create_checkout_session(
|
||||
logger.info(
|
||||
'created_stripe_checkout_session',
|
||||
extra={
|
||||
'stripe_customer_id': customer_info['customer_id'],
|
||||
'stripe_customer_id': customer_id,
|
||||
'user_id': user_id,
|
||||
'org_id': customer_info['org_id'],
|
||||
'amount': body.amount,
|
||||
'checkout_session_id': checkout_session.id,
|
||||
},
|
||||
@@ -173,14 +289,105 @@ async def create_checkout_session(
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
org_id=customer_info['org_id'],
|
||||
price=body.amount,
|
||||
price_code='NA',
|
||||
billing_session_type=BillingSessionType.DIRECT_PAYMENT.value,
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@billing_router.post('/subscription-checkout-session')
|
||||
async def create_subscription_checkout_session(
|
||||
request: Request,
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
# Prevent duplicate subscriptions for the same user
|
||||
with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
existing_active_subscription = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not cancelled
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_active_subscription:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot create subscription: User already has an active subscription that has not been cancelled',
|
||||
)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
subscription_price_data = SUBSCRIPTION_PRICE_DATA[billing_session_type.value]
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
line_items=[
|
||||
{
|
||||
'price_data': subscription_price_data,
|
||||
'quantity': 1,
|
||||
}
|
||||
],
|
||||
mode='subscription',
|
||||
payment_method_types=['card'],
|
||||
saved_payment_method_options={
|
||||
'payment_method_save': 'enabled',
|
||||
},
|
||||
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
subscription_data={
|
||||
'metadata': {
|
||||
'user_id': user_id,
|
||||
'billing_session_type': billing_session_type.value,
|
||||
}
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
'created_stripe_subscription_checkout_session',
|
||||
extra={
|
||||
'stripe_customer_id': customer_id,
|
||||
'user_id': user_id,
|
||||
'checkout_session_id': checkout_session.id,
|
||||
'billing_session_type': billing_session_type.value,
|
||||
},
|
||||
)
|
||||
with session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
price=subscription_price_data['unit_amount'],
|
||||
price_code='NA',
|
||||
billing_session_type=billing_session_type.value,
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(
|
||||
redirect_url=typing.cast(str, checkout_session.url)
|
||||
)
|
||||
|
||||
|
||||
@billing_router.get('/create-subscription-checkout-session')
|
||||
async def create_subscription_checkout_session_via_get(
|
||||
request: Request,
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> RedirectResponse:
|
||||
"""Create a subscription checkout session using a GET request (For easier copy / paste to URL bar)."""
|
||||
validate_saas_environment(request)
|
||||
|
||||
response = await create_subscription_checkout_session(
|
||||
request, billing_session_type, user_id
|
||||
)
|
||||
return RedirectResponse(response.redirect_url)
|
||||
|
||||
|
||||
# Callback endpoint for successful Stripe payments - updates user credits and billing session status
|
||||
@@ -202,6 +409,15 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
# Any non direct payment (Subscription) is processed in the invoice_payment.paid by the webhook
|
||||
if (
|
||||
billing_session.billing_session_type
|
||||
!= BillingSessionType.DIRECT_PAYMENT.value
|
||||
):
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=success', status_code=302
|
||||
)
|
||||
|
||||
stripe_session = stripe.checkout.Session.retrieve(session_id)
|
||||
if stripe_session.status != 'complete':
|
||||
# Hopefully this never happens - we get a redirect from stripe where the payment is not yet complete
|
||||
@@ -215,39 +431,31 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
user = await call_sync_from_async(
|
||||
UserStore.get_user_by_id, billing_session.user_id
|
||||
)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
)
|
||||
new_max_budget = max_budget + add_credits
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
# Update max budget in litellm
|
||||
user_json = await _get_litellm_user(client, billing_session.user_id)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
new_max_budget = (
|
||||
(user_json.get('user_info') or {}).get('max_budget') or 0
|
||||
) + add_credits
|
||||
await _upsert_litellm_user(client, billing_session.user_id, new_max_budget)
|
||||
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
str(user.current_org_id), new_max_budget
|
||||
)
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = add_credits
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
logger.info(
|
||||
'stripe_checkout_success',
|
||||
extra={
|
||||
'amount_subtotal': stripe_session.amount_subtotal,
|
||||
'user_id': billing_session.user_id,
|
||||
'org_id': str(user.current_org_id),
|
||||
'checkout_session_id': billing_session.id,
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = amount_subtotal
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
logger.info(
|
||||
'stripe_checkout_success',
|
||||
extra={
|
||||
'amount_subtotal': stripe_session.amount_subtotal,
|
||||
'user_id': billing_session.user_id,
|
||||
'checkout_session_id': billing_session.id,
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings/billing?checkout=success', status_code=302
|
||||
@@ -277,6 +485,206 @@ async def cancel_callback(session_id: str, request: Request):
|
||||
session.merge(billing_session)
|
||||
session.commit()
|
||||
|
||||
# Redirect credit purchases to billing screen, subscriptions to LLM settings
|
||||
if (
|
||||
billing_session.billing_session_type
|
||||
== BillingSessionType.DIRECT_PAYMENT.value
|
||||
):
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings/billing?checkout=cancel',
|
||||
status_code=302,
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
# If no billing session found, default to LLM settings (subscription flow)
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings/billing?checkout=cancel', status_code=302
|
||||
f'{request.base_url}settings?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
|
||||
@billing_router.post('/stripe-webhook')
|
||||
async def stripe_webhook(request: Request) -> JSONResponse:
|
||||
"""Endpoint for stripe webhooks."""
|
||||
payload = await request.body()
|
||||
sig_header = request.headers.get('stripe-signature')
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
raise HTTPException(status_code=400, detail=f'Invalid payload: {e}')
|
||||
except stripe.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
raise HTTPException(status_code=400, detail=f'Invalid signature: {e}')
|
||||
|
||||
# Handle the event
|
||||
logger.info('stripe_webhook_event', extra={'event': event})
|
||||
event_type = event['type']
|
||||
if event_type == 'invoice.paid':
|
||||
invoice = event['data']['object']
|
||||
amount_paid = invoice.amount_paid
|
||||
metadata = invoice.parent.subscription_details.metadata # type: ignore
|
||||
billing_session_type = metadata.billing_session_type
|
||||
assert (
|
||||
amount_paid == SUBSCRIPTION_PRICE_DATA[billing_session_type]['unit_amount']
|
||||
)
|
||||
user_id = metadata.user_id
|
||||
|
||||
start_at = datetime.now(UTC)
|
||||
if billing_session_type == BillingSessionType.MONTHLY_SUBSCRIPTION.value:
|
||||
end_at = start_at + relativedelta(months=1)
|
||||
else:
|
||||
raise ValueError(f'unknown_billing_session_type:{billing_session_type}')
|
||||
|
||||
with session_maker() as session:
|
||||
subscription_access = SubscriptionAccess(
|
||||
status='ACTIVE',
|
||||
user_id=user_id,
|
||||
start_at=start_at,
|
||||
end_at=end_at,
|
||||
amount_paid=amount_paid,
|
||||
stripe_invoice_payment_id=invoice.payment_intent,
|
||||
stripe_subscription_id=invoice.subscription, # Store Stripe subscription ID
|
||||
)
|
||||
session.add(subscription_access)
|
||||
session.commit()
|
||||
elif event_type == 'customer.subscription.updated':
|
||||
subscription = event['data']['object']
|
||||
subscription_id = subscription['id']
|
||||
|
||||
# Handle subscription cancellation
|
||||
if subscription.get('cancel_at_period_end') is True:
|
||||
with session_maker() as session:
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(
|
||||
SubscriptionAccess.stripe_subscription_id == subscription_id
|
||||
)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access and not subscription_access.cancelled_at:
|
||||
subscription_access.cancelled_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'subscription_cancelled_via_webhook',
|
||||
extra={
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'user_id': subscription_access.user_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
},
|
||||
)
|
||||
elif event_type == 'customer.subscription.deleted':
|
||||
subscription = event['data']['object']
|
||||
subscription_id = subscription['id']
|
||||
|
||||
with session_maker() as session:
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.stripe_subscription_id == subscription_id)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access:
|
||||
subscription_access.status = 'DISABLED'
|
||||
subscription_access.updated_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
# Reset user settings to free tier defaults
|
||||
reset_user_to_free_tier_settings(subscription_access.user_id)
|
||||
|
||||
logger.info(
|
||||
'subscription_expired_reset_to_free_tier',
|
||||
extra={
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'user_id': subscription_access.user_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info('stripe_webhook_unhandled_event_type', extra={'type': event_type})
|
||||
|
||||
return JSONResponse({'status': 'success'})
|
||||
|
||||
|
||||
def reset_user_to_free_tier_settings(user_id: str) -> None:
|
||||
"""Reset user settings to free tier defaults when subscription ends."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
user_settings = settings_store.get_user_settings_by_keycloak_id(
|
||||
user_id, session
|
||||
)
|
||||
|
||||
if user_settings:
|
||||
user_settings.llm_model = get_default_litellm_model()
|
||||
user_settings.llm_api_key = None
|
||||
user_settings.llm_api_key_for_byor = None
|
||||
user_settings.llm_base_url = LITE_LLM_API_URL
|
||||
user_settings.max_budget_per_task = None
|
||||
user_settings.confirmation_mode = False
|
||||
user_settings.enable_solvability_analysis = False
|
||||
user_settings.security_analyzer = 'llm'
|
||||
user_settings.agent = 'CodeActAgent'
|
||||
user_settings.language = 'en'
|
||||
user_settings.enable_default_condenser = True
|
||||
user_settings.enable_sound_notifications = False
|
||||
user_settings.enable_proactive_conversation_starters = True
|
||||
user_settings.user_consents_to_analytics = False
|
||||
|
||||
session.merge(user_settings)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'user_settings_reset_to_free_tier',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'reset_timestamp': datetime.now(UTC).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _get_litellm_user(client: httpx.AsyncClient, user_id: str) -> dict:
|
||||
"""Get a user from litellm with the id matching that given.
|
||||
|
||||
If no such user exists, returns a dummy user in the format:
|
||||
`{'user_id': '<USER_ID>', 'user_info': {'spend': 0}, 'keys': [], 'teams': []}`
|
||||
"""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
async def _upsert_litellm_user(
|
||||
client: httpx.AsyncClient, user_id: str, max_budget: float
|
||||
):
|
||||
"""Insert / Update a user in litellm."""
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/update',
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
json={
|
||||
'user_id': user_id,
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -6,7 +6,7 @@ from threading import Thread
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from sqlalchemy import func, select
|
||||
from storage.database import a_session_maker, engine, session_maker
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.async_utils import wait_all
|
||||
@@ -127,7 +127,7 @@ def _db_check(delay: int):
|
||||
delay: Number of seconds to hold the database connection
|
||||
"""
|
||||
with session_maker() as session:
|
||||
num_users = session.query(User).count()
|
||||
num_users = session.query(UserSettings).count()
|
||||
time.sleep(delay)
|
||||
logger.info(
|
||||
'check',
|
||||
@@ -155,7 +155,7 @@ async def _a_db_check(delay: int):
|
||||
delay: Number of seconds to hold the database connection
|
||||
"""
|
||||
async with a_session_maker() as a_session:
|
||||
stmt = select(func.count(User.id))
|
||||
stmt = select(func.count(UserSettings.id))
|
||||
num_users = await a_session.execute(stmt)
|
||||
await asyncio.sleep(delay)
|
||||
logger.info(f'a_num_users:{num_users.scalar_one()}')
|
||||
|
||||
@@ -21,7 +21,7 @@ from server.utils.conversation_callback_utils import (
|
||||
update_conversation_stats,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.server.shared import conversation_manager
|
||||
|
||||
@@ -226,12 +226,12 @@ def _parse_conversation_id_and_subpath(path: str) -> Tuple[str, str]:
|
||||
|
||||
def _get_user_id(conversation_id: str) -> str:
|
||||
with session_maker() as session:
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
return conversation_metadata.user_id
|
||||
|
||||
|
||||
async def _get_session_api_key(user_id: str, conversation_id: str) -> str | None:
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy.future import select
|
||||
from storage.database import session_maker
|
||||
from storage.feedback import ConversationFeedback
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.server.shared import file_store
|
||||
@@ -33,10 +33,10 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
|
||||
def _verify_conversation():
|
||||
with session_maker() as session:
|
||||
metadata = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
StoredConversationMetadata.conversation_id == conversation_id,
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
@@ -59,8 +58,7 @@ async def github_events(
|
||||
)
|
||||
|
||||
try:
|
||||
# Add timeout to prevent hanging on slow/stalled clients
|
||||
payload = await asyncio.wait_for(request.body(), timeout=15.0)
|
||||
payload = await request.body()
|
||||
verify_github_signature(payload, x_hub_signature_256)
|
||||
|
||||
payload_data = await request.json()
|
||||
@@ -80,12 +78,6 @@ async def github_events(
|
||||
status_code=200,
|
||||
content={'message': 'GitHub events endpoint reached successfully.'},
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning('GitHub webhook request timed out waiting for request body')
|
||||
return JSONResponse(
|
||||
status_code=408,
|
||||
content={'error': 'Request timeout - client took too long to send data.'},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f'Error processing GitHub event: {e}')
|
||||
return JSONResponse(status_code=400, content={'error': 'Invalid payload.'})
|
||||
|
||||
@@ -15,6 +15,7 @@ from integrations.slack.slack_manager import SlackManager
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
from server.auth.constants import (
|
||||
KEYCLOAK_CLIENT_ID,
|
||||
KEYCLOAK_REALM_NAME,
|
||||
@@ -34,11 +35,9 @@ from slack_sdk.web.async_client import AsyncWebClient
|
||||
from storage.database import session_maker
|
||||
from storage.slack_team_store import SlackTeamStore
|
||||
from storage.slack_user import SlackUser
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.shared import config, sio
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
signature_verifier = SignatureVerifier(signing_secret=SLACK_SIGNING_SECRET)
|
||||
slack_router = APIRouter(prefix='/slack')
|
||||
@@ -80,14 +79,6 @@ async def install_callback(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if not config.jwt_secret:
|
||||
logger.error('slack_install_callback_error JWT not configured.')
|
||||
return _html_response(
|
||||
title='Error',
|
||||
description=html.escape('JWT not configured'),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
try:
|
||||
client = AsyncWebClient() # no prepared token needed for this
|
||||
# Complete the installation by calling oauth.v2.access API method
|
||||
@@ -103,17 +94,16 @@ async def install_callback(
|
||||
|
||||
# Create a state variable for keycloak oauth
|
||||
payload = {}
|
||||
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
|
||||
if state:
|
||||
payload = jwt.decode(
|
||||
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
state, jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
)
|
||||
payload['slack_user_id'] = authed_user.get('id')
|
||||
payload['bot_access_token'] = bot_access_token
|
||||
payload['team_id'] = team_id
|
||||
|
||||
state = jwt.encode(
|
||||
payload, config.jwt_secret.get_secret_value(), algorithm='HS256'
|
||||
)
|
||||
state = jwt.encode(payload, jwt_secret.get_secret_value(), algorithm='HS256')
|
||||
|
||||
# Redirect into keycloak
|
||||
scope = quote('openid email profile offline_access')
|
||||
@@ -159,16 +149,9 @@ async def keycloak_callback(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if not config.jwt_secret:
|
||||
logger.error('problem_retrieving_keycloak_tokens JWT not configured.')
|
||||
return _html_response(
|
||||
title='Error',
|
||||
description=html.escape('JWT not configured'),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
|
||||
payload: dict[str, str] = jwt.decode(
|
||||
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
state, jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
)
|
||||
slack_user_id = payload['slack_user_id']
|
||||
bot_access_token = payload['bot_access_token']
|
||||
@@ -197,13 +180,6 @@ async def keycloak_callback(
|
||||
|
||||
user_info = await token_manager.get_user_info(keycloak_access_token)
|
||||
keycloak_user_id = user_info['sub']
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, keycloak_user_id)
|
||||
if not user:
|
||||
return _html_response(
|
||||
title='Failed to authenticate.',
|
||||
description=f'Please re-login into <a href="{HOST_URL}" style="color:#ecedee;text-decoration:underline;">OpenHands Cloud</a>. Then try <a href="https://docs.all-hands.dev/usage/cloud/slack-installation" style="color:#ecedee;text-decoration:underline;">installing the OpenHands Slack App</a> again',
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# These tokens are offline access tokens - store them!
|
||||
await token_manager.store_offline_token(keycloak_user_id, keycloak_refresh_token)
|
||||
@@ -235,7 +211,6 @@ async def keycloak_callback(
|
||||
slack_display_name = slack_user_info.data['user']['profile']['display_name']
|
||||
slack_user = SlackUser(
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
org_id=user.current_org_id,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_display_name=slack_display_name,
|
||||
)
|
||||
@@ -330,7 +305,7 @@ async def on_form_interaction(request: Request, background_tasks: BackgroundTask
|
||||
|
||||
body = await request.body()
|
||||
form = await request.form()
|
||||
payload = json.loads(form.get('payload'))
|
||||
payload = json.loads(form.get('payload')) # type: ignore[arg-type]
|
||||
|
||||
logger.info('slack_on_form_interaction', extra={'payload': payload})
|
||||
|
||||
|
||||
@@ -20,10 +20,7 @@ from server.utils.conversation_callback_utils import (
|
||||
from sqlalchemy import orm
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import LLMConfig, OpenHandsConfig
|
||||
@@ -73,11 +70,6 @@ RUNTIME_CONVERSATION_URL = RUNTIME_URL_PATTERN + (
|
||||
else '/api/conversations/{conversation_id}'
|
||||
)
|
||||
|
||||
RUNTIME_USERNAME = os.getenv('RUNTIME_USERNAME')
|
||||
SU_TO_USER = os.getenv('SU_TO_USER', 'false')
|
||||
truthy = {'1', 'true', 't', 'yes', 'y', 'on'}
|
||||
SU_TO_USER = str(SU_TO_USER.lower() in truthy).lower()
|
||||
|
||||
# Time in seconds before a Redis entry is considered expired if not refreshed
|
||||
_REDIS_ENTRY_TIMEOUT_SECONDS = 300
|
||||
|
||||
@@ -533,18 +525,16 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
"""
|
||||
|
||||
with session_maker() as session:
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id
|
||||
)
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation_metadata_saas:
|
||||
if not conversation_metadata:
|
||||
raise ValueError(f'No conversation found {conversation_id}')
|
||||
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
return conversation_metadata.user_id
|
||||
|
||||
async def _get_runtime_status_from_nested_runtime(
|
||||
self, session_api_key: Any | None, nested_url: str, conversation_id: str
|
||||
@@ -782,11 +772,7 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
env_vars['SERVE_FRONTEND'] = '0'
|
||||
env_vars['RUNTIME'] = 'local'
|
||||
# TODO: In the long term we may come up with a more secure strategy for user management within the nested runtime.
|
||||
env_vars['USER'] = (
|
||||
RUNTIME_USERNAME
|
||||
if RUNTIME_USERNAME
|
||||
else ('openhands' if config.run_as_openhands else 'root')
|
||||
)
|
||||
env_vars['USER'] = 'openhands' if config.run_as_openhands else 'root'
|
||||
env_vars['PERMITTED_CORS_ORIGINS'] = ','.join(PERMITTED_CORS_ORIGINS)
|
||||
env_vars['port'] = '60000'
|
||||
# TODO: These values are static in the runtime-api project, but do not get copied into the runtime ENV
|
||||
@@ -803,7 +789,6 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
env_vars['INITIAL_NUM_WARM_SERVERS'] = '1'
|
||||
env_vars['INIT_GIT_IN_EMPTY_WORKSPACE'] = '1'
|
||||
env_vars['ENABLE_V1'] = '0'
|
||||
env_vars['SU_TO_USER'] = SU_TO_USER
|
||||
|
||||
# We need this for LLM traces tracking to identify the source of the LLM calls
|
||||
env_vars['WEB_HOST'] = WEB_HOST
|
||||
@@ -873,17 +858,9 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
with session_maker() as session:
|
||||
# Only include conversations updated in the past week
|
||||
one_week_ago = datetime.now(UTC) - timedelta(days=7)
|
||||
query = (
|
||||
session.query(StoredConversationMetadata.conversation_id)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
StoredConversationMetadata.last_updated_at >= one_week_ago,
|
||||
)
|
||||
query = session.query(StoredConversationMetadata.conversation_id).filter(
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
StoredConversationMetadata.last_updated_at >= one_week_ago,
|
||||
)
|
||||
user_conversation_ids = set(query)
|
||||
return user_conversation_ids
|
||||
@@ -957,16 +934,11 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
if conversation_metadata is None or conversation_metadata_saas is None:
|
||||
if conversation_metadata is None:
|
||||
# Conversation is running in different server
|
||||
return
|
||||
|
||||
user_id = conversation_metadata_saas.user_id
|
||||
user_id = conversation_metadata.user_id
|
||||
|
||||
# Get the id of the next event which is not present
|
||||
events_dir = get_conversation_events_dir(
|
||||
|
||||
@@ -11,6 +11,7 @@ from storage.conversation_callback import (
|
||||
)
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
from openhands.core.schema.agent import AgentState
|
||||
@@ -125,12 +126,6 @@ def update_conversation_metadata(conversation_id: str, content: dict):
|
||||
conversation_id: The conversation ID to update
|
||||
content: The metadata content to update
|
||||
"""
|
||||
|
||||
# Local import fixes the lazy-loading problem
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'update_conversation_metadata',
|
||||
extra={
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
from storage.api_key import ApiKey
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.billing_session_type import BillingSessionType
|
||||
from storage.conversation_callback import CallbackStatus, ConversationCallback
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.experiment_assignment import ExperimentAssignment
|
||||
from storage.feedback import ConversationFeedback, Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_dc_conversation import JiraDcConversation
|
||||
from storage.jira_dc_user import JiraDcUser
|
||||
from storage.jira_dc_workspace import JiraDcWorkspace
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
from storage.linear_conversation import LinearConversation
|
||||
from storage.linear_user import LinearUser
|
||||
from storage.linear_workspace import LinearWorkspace
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.openhands_pr import OpenhandsPR
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.proactive_convos import ProactiveConversation
|
||||
from storage.role import Role
|
||||
from storage.slack_conversation import SlackConversation
|
||||
from storage.slack_team import SlackTeam
|
||||
from storage.slack_user import SlackUser
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.subscription_access_status import SubscriptionAccessStatus
|
||||
from storage.user import User
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
__all__ = [
|
||||
'ApiKey',
|
||||
'AuthTokens',
|
||||
'BillingSession',
|
||||
'BillingSessionType',
|
||||
'CallbackStatus',
|
||||
'ConversationCallback',
|
||||
'ConversationFeedback',
|
||||
'StoredConversationMetadataSaas',
|
||||
'ConversationWork',
|
||||
'ExperimentAssignment',
|
||||
'Feedback',
|
||||
'GithubAppInstallation',
|
||||
'GitlabWebhook',
|
||||
'JiraConversation',
|
||||
'JiraDcConversation',
|
||||
'JiraDcUser',
|
||||
'JiraDcWorkspace',
|
||||
'JiraUser',
|
||||
'JiraWorkspace',
|
||||
'LinearConversation',
|
||||
'LinearUser',
|
||||
'LinearWorkspace',
|
||||
'MaintenanceTask',
|
||||
'MaintenanceTaskStatus',
|
||||
'OpenhandsPR',
|
||||
'Org',
|
||||
'OrgMember',
|
||||
'ProactiveConversation',
|
||||
'Role',
|
||||
'SlackConversation',
|
||||
'SlackTeam',
|
||||
'SlackUser',
|
||||
'StoredConversationMetadata',
|
||||
'StoredOfflineToken',
|
||||
'StoredRepository',
|
||||
'StoredCustomSecrets',
|
||||
'StripeCustomer',
|
||||
'SubscriptionAccess',
|
||||
'SubscriptionAccessStatus',
|
||||
'User',
|
||||
'UserRepositoryMap',
|
||||
'UserSettings',
|
||||
'WebhookStatus',
|
||||
]
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, DateTime, Integer, String, text
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -13,13 +11,9 @@ class ApiKey(Base):
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
key = Column(String(255), nullable=False, unique=True, index=True)
|
||||
user_id = Column(String(255), nullable=False, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
name = Column(String(255), nullable=True)
|
||||
created_at = Column(
|
||||
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
|
||||
)
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
expires_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='api_keys')
|
||||
|
||||
@@ -9,7 +9,6 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.api_key import ApiKey
|
||||
from storage.database import session_maker
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -37,15 +36,10 @@ class ApiKeyStore:
|
||||
The generated API key
|
||||
"""
|
||||
api_key = self.generate_api_key()
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
|
||||
with self.session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=name,
|
||||
expires_at=expires_at,
|
||||
key=api_key, user_id=user_id, name=name, expires_at=expires_at
|
||||
)
|
||||
session.add(key_record)
|
||||
session.commit()
|
||||
@@ -105,15 +99,8 @@ class ApiKeyStore:
|
||||
|
||||
def list_api_keys(self, user_id: str) -> list[dict]:
|
||||
"""List all API keys for a user."""
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
keys = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
keys = session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
||||
|
||||
return [
|
||||
{
|
||||
@@ -128,14 +115,9 @@ class ApiKeyStore:
|
||||
]
|
||||
|
||||
def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
keys: list[ApiKey] = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
||||
)
|
||||
for key in keys:
|
||||
if key.name == 'MCP_API_KEY':
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import DECIMAL, Column, DateTime, Enum, ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import DECIMAL, Column, DateTime, Enum, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -13,9 +11,9 @@ class BillingSession(Base): # type: ignore
|
||||
"""
|
||||
|
||||
__tablename__ = 'billing_sessions'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
status = Column(
|
||||
Enum(
|
||||
'in_progress',
|
||||
@@ -26,6 +24,15 @@ class BillingSession(Base): # type: ignore
|
||||
),
|
||||
default='in_progress',
|
||||
)
|
||||
billing_session_type = Column(
|
||||
Enum(
|
||||
'DIRECT_PAYMENT',
|
||||
'MONTHLY_SUBSCRIPTION',
|
||||
name='billing_session_type_enum',
|
||||
),
|
||||
nullable=False,
|
||||
default='DIRECT_PAYMENT',
|
||||
)
|
||||
price = Column(DECIMAL(19, 4), nullable=False)
|
||||
price_code = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
@@ -36,6 +43,3 @@ class BillingSession(Base): # type: ignore
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC), # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='billing_sessions')
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
@@ -8,9 +7,6 @@ from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.util import await_only
|
||||
|
||||
# Check if we're running in a test environment
|
||||
IS_TESTING = 'pytest' in sys.modules
|
||||
|
||||
DB_HOST = os.environ.get('DB_HOST', 'localhost') # for non-GCP environments
|
||||
DB_PORT = os.environ.get('DB_PORT', '5432') # for non-GCP environments
|
||||
DB_USER = os.environ.get('DB_USER', 'postgres')
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
import binascii
|
||||
import hashlib
|
||||
from base64 import b64decode, b64encode
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from pydantic import SecretStr
|
||||
from server.config import get_config
|
||||
|
||||
_jwt_service = None
|
||||
_fernet = None
|
||||
|
||||
|
||||
def encrypt_model(encrypt_keys: list, model_instance) -> dict:
|
||||
return encrypt_kwargs(encrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def decrypt_model(decrypt_keys: list, model_instance) -> dict:
|
||||
return decrypt_kwargs(decrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def encrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if isinstance(value, dict):
|
||||
encrypt_kwargs(encrypt_keys, value)
|
||||
continue
|
||||
|
||||
if key in encrypt_keys:
|
||||
value = encrypt_value(value)
|
||||
kwargs[key] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def decrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
value = decrypt_value(value)
|
||||
kwargs[key] = value
|
||||
except binascii.Error:
|
||||
pass # Key is in legacy format...
|
||||
return kwargs
|
||||
|
||||
|
||||
def encrypt_value(value: str | SecretStr) -> str:
|
||||
return get_jwt_service().create_jwe_token(
|
||||
{'v': value.get_secret_value() if isinstance(value, SecretStr) else value}
|
||||
)
|
||||
|
||||
|
||||
def decrypt_value(value: str | SecretStr) -> str:
|
||||
token = get_jwt_service().decrypt_jwe_token(
|
||||
value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
)
|
||||
return token['v']
|
||||
|
||||
|
||||
def get_jwt_service():
|
||||
from openhands.app_server.config import get_global_config
|
||||
|
||||
global _jwt_service
|
||||
if _jwt_service is None:
|
||||
jwt_service_injector = get_global_config().jwt
|
||||
assert jwt_service_injector is not None
|
||||
_jwt_service = jwt_service_injector.get_jwt_service()
|
||||
return _jwt_service
|
||||
|
||||
|
||||
def decrypt_legacy_model(decrypt_keys: list, model_instance) -> dict:
|
||||
return decrypt_legacy_kwargs(decrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def decrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
value = decrypt_legacy_value(value)
|
||||
kwargs[key] = value
|
||||
except binascii.Error:
|
||||
pass # Key is in legacy format...
|
||||
except InvalidToken:
|
||||
pass # Key not encrypted...
|
||||
return kwargs
|
||||
|
||||
|
||||
def decrypt_legacy_value(value: str | SecretStr) -> str:
|
||||
if isinstance(value, SecretStr):
|
||||
return (
|
||||
get_fernet().decrypt(b64decode(value.get_secret_value().encode())).decode()
|
||||
)
|
||||
else:
|
||||
return get_fernet().decrypt(b64decode(value.encode())).decode()
|
||||
|
||||
|
||||
def get_fernet():
|
||||
global _fernet
|
||||
if _fernet is None:
|
||||
jwt_secret = get_config().jwt_secret.get_secret_value()
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
_fernet = Fernet(fernet_key)
|
||||
return _fernet
|
||||
|
||||
|
||||
def model_to_kwargs(model_instance):
|
||||
return {
|
||||
column.name: getattr(model_instance, column.name)
|
||||
for column in model_instance.__table__.columns
|
||||
}
|
||||
@@ -1,16 +1,7 @@
|
||||
import sys
|
||||
from enum import IntEnum
|
||||
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy import ARRAY, Boolean, Column, DateTime, Integer, String, Text, text
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
|
||||
@@ -1,674 +0,0 @@
|
||||
"""
|
||||
Store class for managing organizational settings.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
class LiteLlmManager:
|
||||
"""Manage LiteLLM interactions."""
|
||||
|
||||
@staticmethod
|
||||
async def create_entries(
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
oss_settings: Settings,
|
||||
) -> Settings | None:
|
||||
logger.info(
|
||||
'SettingsStore:update_settings_with_litellm_default:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
key = LITE_LLM_API_KEY
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
token_manager = TokenManager()
|
||||
keycloak_user_info = (
|
||||
await token_manager.get_user_info_from_user_id(keycloak_user_id) or {}
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
)
|
||||
|
||||
await LiteLlmManager._create_user(
|
||||
client, keycloak_user_info.get('email'), keycloak_user_id
|
||||
)
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
)
|
||||
|
||||
key = await LiteLlmManager._generate_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
org_id,
|
||||
f'OpenHands Cloud - user {keycloak_user_id}',
|
||||
None,
|
||||
)
|
||||
|
||||
oss_settings.agent = 'CodeActAgent'
|
||||
# Use the model corresponding to the current user settings version
|
||||
oss_settings.llm_model = get_default_litellm_model()
|
||||
oss_settings.llm_api_key = SecretStr(key)
|
||||
oss_settings.llm_base_url = LITE_LLM_API_URL
|
||||
return oss_settings
|
||||
|
||||
@staticmethod
|
||||
async def migrate_entries(
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
user_settings: UserSettings,
|
||||
) -> UserSettings | None:
|
||||
logger.info(
|
||||
'SettingsStore:umigrate_lite_llm_entries:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
user_json = await LiteLlmManager._get_user(client, keycloak_user_id)
|
||||
if not user_json:
|
||||
return None
|
||||
user_info = user_json['user_info']
|
||||
max_budget = user_info.get('max_budget', 0.0)
|
||||
if not max_budget:
|
||||
# if max_budget is None, then we've already migrated the User
|
||||
return None
|
||||
spend = user_info.get('spend', 0.0)
|
||||
credits = max(max_budget - spend, 0.0)
|
||||
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
await LiteLlmManager._update_user(
|
||||
client, keycloak_user_id, max_budget=1000000000.0
|
||||
)
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key:
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key,
|
||||
team_id=org_id,
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key_for_byor:
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key_for_byor,
|
||||
team_id=org_id,
|
||||
)
|
||||
|
||||
user_settings.agent = 'CodeActAgent'
|
||||
# Use the model corresponding to the current user settings version
|
||||
user_settings.llm_model = get_default_litellm_model()
|
||||
user_settings.llm_base_url = LITE_LLM_API_URL
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
async def update_team_and_users_budget(
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
await LiteLlmManager._update_team(client, team_id, None, max_budget)
|
||||
team_info = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_info:
|
||||
return None
|
||||
for membership in team_info.get('team_memberships', []):
|
||||
user_id = membership.get('user_id')
|
||||
if not user_id:
|
||||
continue
|
||||
await LiteLlmManager._update_user_in_team(
|
||||
client, user_id, team_id, max_budget
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _create_team(
|
||||
client: httpx.AsyncClient,
|
||||
team_alias: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/new',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'team_alias': team_alias,
|
||||
'models': [],
|
||||
'max_budget': max_budget,
|
||||
'spend': 0,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
# Team failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if (
|
||||
response.status_code == 400
|
||||
and 'already exists. Please use a different team id' in response.text
|
||||
):
|
||||
# team already exists, so update, then return
|
||||
await LiteLlmManager._update_team(
|
||||
client, team_id, team_alias, max_budget
|
||||
)
|
||||
return
|
||||
logger.error(
|
||||
'error_creating_litellm_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'team_id': team_id,
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_team(client: httpx.AsyncClient, team_id: str) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
"""Get a team from litellm with the id matching that given."""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/team/info?team_id={team_id}',
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
async def _update_team(
|
||||
client: httpx.AsyncClient,
|
||||
team_id: str,
|
||||
team_alias: str | None,
|
||||
max_budget: float | None,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
json_data: dict[str, Any] = {
|
||||
'team_id': team_id,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
}
|
||||
|
||||
if max_budget is not None:
|
||||
json_data['max_budget'] = max_budget
|
||||
|
||||
if team_alias is not None:
|
||||
json_data['team_alias'] = team_alias
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/update',
|
||||
json=json_data,
|
||||
)
|
||||
|
||||
# Team failed to update in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _create_user(
|
||||
client: httpx.AsyncClient,
|
||||
email: str | None,
|
||||
keycloak_user_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': email,
|
||||
'models': [],
|
||||
'user_id': keycloak_user_id,
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': False,
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
if not response.is_success:
|
||||
logger.warning(
|
||||
'duplicate_user_email',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'email': email,
|
||||
},
|
||||
)
|
||||
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': None,
|
||||
'models': [],
|
||||
'user_id': keycloak_user_id,
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': False,
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# User failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if response.status_code == 400 and 'already exists' in response.text:
|
||||
# user already exists, just return
|
||||
return
|
||||
logger.error(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'email': None,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_user(client: httpx.AsyncClient, user_id: str) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
"""Get a user from litellm with the id matching that given."""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
async def _update_user(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
**kwargs,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
payload = {
|
||||
'user_id': keycloak_user_id,
|
||||
}
|
||||
payload.update(kwargs)
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/update',
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _update_key(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
key: str,
|
||||
**kwargs,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
payload = {
|
||||
'key': key,
|
||||
}
|
||||
payload.update(kwargs)
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/update',
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_key',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _delete_user(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [keycloak_user_id]}
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_deleting_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _add_user_to_team(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_add',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'member': {'user_id': keycloak_user_id, 'role': 'user'},
|
||||
'max_budget_in_team': max_budget,
|
||||
},
|
||||
)
|
||||
# Failed to add user to team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_adding_litellm_user_to_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_user_team_info(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
team_info = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_info:
|
||||
return None
|
||||
|
||||
# Filter team_memberships based on team_id and keycloak_user_id
|
||||
user_membership = next(
|
||||
(
|
||||
membership
|
||||
for membership in team_info.get('team_memberships', [])
|
||||
if membership.get('user_id') == keycloak_user_id
|
||||
and membership.get('team_id') == team_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return user_membership
|
||||
|
||||
@staticmethod
|
||||
async def _update_user_in_team(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_update',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'max_budget_in_team': max_budget,
|
||||
},
|
||||
)
|
||||
# Failed to update user in team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_user_in_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _generate_key(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str | None,
|
||||
key_alias: str | None,
|
||||
metadata: dict | None,
|
||||
) -> str | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
json_data: dict[str, Any] = {
|
||||
'user_id': keycloak_user_id,
|
||||
'models': [],
|
||||
}
|
||||
|
||||
if team_id is not None:
|
||||
json_data['team_id'] = team_id
|
||||
|
||||
if key_alias is not None:
|
||||
json_data['key_alias'] = key_alias
|
||||
|
||||
if metadata is not None:
|
||||
json_data['metadata'] = metadata
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json=json_data,
|
||||
)
|
||||
# Failed to generate user key for team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_generate_user_team_key',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
'team_id': team_id,
|
||||
'key_alias': key_alias,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json['key']
|
||||
logger.info(
|
||||
'LiteLlmManager:_lite_llm_generate_user_team_key:key_created',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'team_id': team_id,
|
||||
'key_alias': key_alias,
|
||||
},
|
||||
)
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
async def _get_key_info(
|
||||
client: httpx.AsyncClient,
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
) -> dict | None:
|
||||
from storage.user_store import UserStore
|
||||
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, keycloak_user_id)
|
||||
if not user:
|
||||
return {}
|
||||
|
||||
org_member = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return {}
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/key/info?key={org_member.llm_api_key}'
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key_info = response_json.get('info')
|
||||
if not key_info:
|
||||
return {}
|
||||
return {
|
||||
'key_max_budget': key_info.get('max_budget'),
|
||||
'key_spend': key_info.get('spend'),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _delete_key(
|
||||
client: httpx.AsyncClient,
|
||||
key_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/delete',
|
||||
json={
|
||||
'keys': [key_id],
|
||||
},
|
||||
)
|
||||
# Failed to key...
|
||||
if not response.is_success:
|
||||
if response.status_code == 404:
|
||||
# key doesn't exist, just return
|
||||
return
|
||||
logger.error(
|
||||
'error_deleting_key',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.info(
|
||||
'LiteLlmManager:_delete_key:key_deleted',
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def with_http_client(
|
||||
internal_fn: Callable[..., Awaitable[Any]],
|
||||
) -> Callable[..., Awaitable[Any]]:
|
||||
@functools.wraps(internal_fn)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async with httpx.AsyncClient(
|
||||
headers={'x-goog-api-key': LITE_LLM_API_KEY}
|
||||
) as client:
|
||||
return await internal_fn(client, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
# Public methods with injected client
|
||||
create_team = staticmethod(with_http_client(_create_team))
|
||||
get_team = staticmethod(with_http_client(_get_team))
|
||||
update_team = staticmethod(with_http_client(_update_team))
|
||||
create_user = staticmethod(with_http_client(_create_user))
|
||||
get_user = staticmethod(with_http_client(_get_user))
|
||||
update_user = staticmethod(with_http_client(_update_user))
|
||||
delete_user = staticmethod(with_http_client(_delete_user))
|
||||
add_user_to_team = staticmethod(with_http_client(_add_user_to_team))
|
||||
get_user_team_info = staticmethod(with_http_client(_get_user_team_info))
|
||||
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
|
||||
generate_key = staticmethod(with_http_client(_generate_key))
|
||||
get_key_info = staticmethod(with_http_client(_get_key_info))
|
||||
delete_key = staticmethod(with_http_client(_delete_key))
|
||||
@@ -1,117 +0,0 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import SecretStr
|
||||
from server.constants import DEFAULT_BILLING_MARGIN
|
||||
from sqlalchemy import JSON, UUID, Boolean, Column, Float, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import decrypt_value, encrypt_value
|
||||
|
||||
|
||||
class Org(Base): # type: ignore
|
||||
"""Organization model."""
|
||||
|
||||
__tablename__ = 'org'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
contact_name = Column(String, nullable=True)
|
||||
contact_email = Column(String, nullable=True)
|
||||
agent = Column(String, nullable=True)
|
||||
default_max_iterations = Column(Integer, nullable=True)
|
||||
security_analyzer = Column(String, nullable=True)
|
||||
confirmation_mode = Column(Boolean, nullable=True, default=False)
|
||||
default_llm_model = Column(String, nullable=True)
|
||||
# encrypted column, don't set directly, set without the underscore
|
||||
_default_llm_api_key_for_byor = Column(String, nullable=True)
|
||||
default_llm_base_url = Column(String, nullable=True)
|
||||
remote_runtime_resource_factor = Column(Integer, nullable=True)
|
||||
enable_default_condenser = Column(Boolean, nullable=False, default=True)
|
||||
billing_margin = Column(Float, nullable=True, default=DEFAULT_BILLING_MARGIN)
|
||||
enable_proactive_conversation_starters = Column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
sandbox_base_container_image = Column(String, nullable=True)
|
||||
sandbox_runtime_container_image = Column(String, nullable=True)
|
||||
org_version = Column(Integer, nullable=False, default=0)
|
||||
mcp_config = Column(JSON, nullable=True)
|
||||
# encrypted column, don't set directly, set without the underscore
|
||||
_search_api_key = Column(String, nullable=True)
|
||||
# encrypted column, don't set directly, set without the underscore
|
||||
_sandbox_api_key = Column(String, nullable=True)
|
||||
max_budget_per_task = Column(Float, nullable=True)
|
||||
enable_solvability_analysis = Column(Boolean, nullable=True, default=False)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
conversation_expiration = Column(Integer, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org_members = relationship('OrgMember', back_populates='org')
|
||||
current_users = relationship('User', back_populates='current_org')
|
||||
billing_sessions = relationship('BillingSession', back_populates='org')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='org'
|
||||
)
|
||||
user_secrets = relationship('StoredCustomSecrets', back_populates='org')
|
||||
api_keys = relationship('ApiKey', back_populates='org')
|
||||
slack_conversations = relationship('SlackConversation', back_populates='org')
|
||||
slack_users = relationship('SlackUser', back_populates='org')
|
||||
stripe_customers = relationship('StripeCustomer', back_populates='org')
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Handle known SQLAlchemy columns directly
|
||||
for key in list(kwargs):
|
||||
if hasattr(self.__class__, key):
|
||||
setattr(self, key, kwargs.pop(key))
|
||||
|
||||
# Handle custom property-style fields
|
||||
if 'default_llm_api_key_for_byor' in kwargs:
|
||||
self.default_llm_api_key_for_byor = kwargs.pop(
|
||||
'default_llm_api_key_for_byor'
|
||||
)
|
||||
if 'search_api_key' in kwargs:
|
||||
self.search_api_key = kwargs.pop('search_api_key')
|
||||
if 'sandbox_api_key' in kwargs:
|
||||
self.sandbox_api_key = kwargs.pop('sandbox_api_key')
|
||||
|
||||
if kwargs:
|
||||
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
|
||||
|
||||
@property
|
||||
def default_llm_api_key_for_byor(self) -> SecretStr | None:
|
||||
if self._default_llm_api_key_for_byor:
|
||||
decrypted = decrypt_value(self._default_llm_api_key_for_byor)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@default_llm_api_key_for_byor.setter
|
||||
def default_llm_api_key_for_byor(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._default_llm_api_key_for_byor = encrypt_value(raw) if raw else None
|
||||
|
||||
@property
|
||||
def search_api_key(self) -> SecretStr | None:
|
||||
if self._search_api_key:
|
||||
decrypted = decrypt_value(self._search_api_key)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@search_api_key.setter
|
||||
def search_api_key(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._search_api_key = encrypt_value(raw) if raw else None
|
||||
|
||||
@property
|
||||
def sandbox_api_key(self) -> SecretStr | None:
|
||||
if self._sandbox_api_key:
|
||||
decrypted = decrypt_value(self._sandbox_api_key)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@sandbox_api_key.setter
|
||||
def sandbox_api_key(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._sandbox_api_key = encrypt_value(raw) if raw else None
|
||||
@@ -1,67 +0,0 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization-Member relationship.
|
||||
"""
|
||||
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy import UUID, Column, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import decrypt_value, encrypt_value
|
||||
|
||||
|
||||
class OrgMember(Base): # type: ignore
|
||||
"""Junction table for organization-member relationships with roles."""
|
||||
|
||||
__tablename__ = 'org_member'
|
||||
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), primary_key=True)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), primary_key=True)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
|
||||
_llm_api_key = Column(String, nullable=False)
|
||||
max_iterations = Column(Integer, nullable=True)
|
||||
llm_model = Column(String, nullable=True)
|
||||
_llm_api_key_for_byor = Column(String, nullable=True)
|
||||
llm_base_url = Column(String, nullable=True)
|
||||
status = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='org_members')
|
||||
user = relationship('User', back_populates='org_members')
|
||||
role = relationship('Role', back_populates='org_members')
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Handle known SQLAlchemy columns directly
|
||||
for key in list(kwargs):
|
||||
if hasattr(self.__class__, key):
|
||||
setattr(self, key, kwargs.pop(key))
|
||||
|
||||
# Handle custom property-style fields
|
||||
if 'llm_api_key' in kwargs:
|
||||
self.llm_api_key = kwargs.pop('llm_api_key')
|
||||
if 'llm_api_key_for_byor' in kwargs:
|
||||
self.llm_api_key_for_byor = kwargs.pop('llm_api_key_for_byor')
|
||||
|
||||
if kwargs:
|
||||
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
|
||||
|
||||
@property
|
||||
def llm_api_key(self) -> SecretStr:
|
||||
decrypted = decrypt_value(self._llm_api_key)
|
||||
return SecretStr(decrypted)
|
||||
|
||||
@llm_api_key.setter
|
||||
def llm_api_key(self, value: str | SecretStr):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._llm_api_key = encrypt_value(raw)
|
||||
|
||||
@property
|
||||
def llm_api_key_for_byor(self) -> SecretStr | None:
|
||||
if self._llm_api_key_for_byor:
|
||||
decrypted = decrypt_value(self._llm_api_key_for_byor)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@llm_api_key_for_byor.setter
|
||||
def llm_api_key_for_byor(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._llm_api_key_for_byor = encrypt_value(raw) if raw else None
|
||||
@@ -1,125 +0,0 @@
|
||||
"""
|
||||
Store class for managing organization-member relationships.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from storage.database import session_maker
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
class OrgMemberStore:
|
||||
"""Store for managing organization-member relationships."""
|
||||
|
||||
@staticmethod
|
||||
def add_user_to_org(
|
||||
org_id: UUID,
|
||||
user_id: UUID,
|
||||
role_id: int,
|
||||
llm_api_key: str,
|
||||
status: Optional[str] = None,
|
||||
) -> OrgMember:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
with session_maker() as session:
|
||||
org_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
llm_api_key=llm_api_key,
|
||||
status=status,
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(org_member)
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def get_org_member(org_id: UUID, user_id: int) -> Optional[OrgMember]:
|
||||
"""Get organization-user relationship."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs(user_id: int) -> list[OrgMember]:
|
||||
"""Get all organizations for a user."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.user_id == user_id).all()
|
||||
|
||||
@staticmethod
|
||||
def get_org_members(org_id: UUID) -> list[OrgMember]:
|
||||
"""Get all users in an organization."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.org_id == org_id).all()
|
||||
|
||||
@staticmethod
|
||||
def update_org_member(org_member: OrgMember) -> None:
|
||||
"""Update an organization-member relationship."""
|
||||
with session_maker() as session:
|
||||
session.merge(org_member)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def update_user_role_in_org(
|
||||
org_id: UUID, user_id: int, role_id: int, status: Optional[str] = None
|
||||
) -> Optional[OrgMember]:
|
||||
"""Update user's role in an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
org_member.role_id = role_id
|
||||
if status is not None:
|
||||
org_member.status = status
|
||||
|
||||
session.commit()
|
||||
session.refresh(org_member)
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def remove_user_from_org(org_id: UUID, user_id: int) -> bool:
|
||||
"""Remove a user from an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not org_member:
|
||||
return False
|
||||
|
||||
session.delete(org_member)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: Settings):
|
||||
kwargs = {
|
||||
normalized: getattr(settings, normalized)
|
||||
for c in OrgMember.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {
|
||||
normalized: getattr(user_settings, normalized)
|
||||
for c in OrgMember.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
@@ -1,139 +0,0 @@
|
||||
"""
|
||||
Store class for managing organizations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from server.constants import ORG_SETTINGS_VERSION, get_default_litellm_model
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
class OrgStore:
|
||||
"""Store for managing organizations."""
|
||||
|
||||
@staticmethod
|
||||
def create_org(
|
||||
kwargs: dict,
|
||||
) -> Org:
|
||||
"""Create a new organization."""
|
||||
with session_maker() as session:
|
||||
org = Org(**kwargs)
|
||||
org.org_version = ORG_SETTINGS_VERSION
|
||||
org.default_llm_model = get_default_litellm_model()
|
||||
session.add(org)
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_org_by_id(org_id: UUID) -> Org | None:
|
||||
"""Get organization by ID."""
|
||||
with session_maker() as session:
|
||||
return session.query(Org).filter(Org.id == org_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_current_org_from_keycloak_user_id(keycloak_user_id: str) -> Org | None:
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == UUID(keycloak_user_id))
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
logger.warning(f'User not found for ID {keycloak_user_id}')
|
||||
return None
|
||||
org_id = user.current_org_id
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.warning(
|
||||
f'Org not found for ID {org_id} as the current org for user {keycloak_user_id}'
|
||||
)
|
||||
return None
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_org_by_name(name: str) -> Org | None:
|
||||
"""Get organization by name."""
|
||||
with session_maker() as session:
|
||||
return session.query(Org).filter(Org.name == name).first()
|
||||
|
||||
@staticmethod
|
||||
def list_orgs() -> list[Org]:
|
||||
"""List all organizations."""
|
||||
with session_maker() as session:
|
||||
orgs = session.query(Org).all()
|
||||
return orgs
|
||||
|
||||
@staticmethod
|
||||
def update_org(
|
||||
org_id: UUID,
|
||||
kwargs: dict,
|
||||
) -> Optional[Org]:
|
||||
"""Update organization details."""
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
return None
|
||||
|
||||
if 'id' in kwargs:
|
||||
kwargs.pop('id')
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: Settings):
|
||||
kwargs = {}
|
||||
|
||||
for c in Org.__table__.columns:
|
||||
# Normalize for lookup
|
||||
normalized = (
|
||||
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
|
||||
)
|
||||
|
||||
if not hasattr(settings, normalized):
|
||||
continue
|
||||
|
||||
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
|
||||
key = c.name
|
||||
if key.startswith('_'):
|
||||
key = key[1:] # remove only the very first leading underscore
|
||||
|
||||
kwargs[key] = getattr(settings, normalized)
|
||||
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {}
|
||||
|
||||
for c in Org.__table__.columns:
|
||||
# Normalize for lookup
|
||||
normalized = (
|
||||
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
|
||||
)
|
||||
|
||||
if not hasattr(user_settings, normalized):
|
||||
continue
|
||||
|
||||
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
|
||||
key = c.name
|
||||
if key.startswith('_'):
|
||||
key = key[1:] # remove only the very first leading underscore
|
||||
|
||||
kwargs[key] = getattr(user_settings, normalized)
|
||||
|
||||
return kwargs
|
||||
@@ -1,21 +0,0 @@
|
||||
"""
|
||||
SQLAlchemy model for Role.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class Role(Base): # type: ignore
|
||||
"""Role model for user permissions."""
|
||||
|
||||
__tablename__ = 'role'
|
||||
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
rank = Column(Integer, nullable=False)
|
||||
|
||||
# Relationships
|
||||
users = relationship('User', back_populates='role')
|
||||
org_members = relationship('OrgMember', back_populates='role')
|
||||
@@ -1,40 +0,0 @@
|
||||
"""
|
||||
Store class for managing roles.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from storage.database import session_maker
|
||||
from storage.role import Role
|
||||
|
||||
|
||||
class RoleStore:
|
||||
"""Store for managing roles."""
|
||||
|
||||
@staticmethod
|
||||
def create_role(name: str, rank: int) -> Role:
|
||||
"""Create a new role."""
|
||||
with session_maker() as session:
|
||||
role = Role(name=name, rank=rank)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
session.refresh(role)
|
||||
return role
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_id(role_id: int) -> Optional[Role]:
|
||||
"""Get role by ID."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.id == role_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_name(name: str) -> Optional[Role]:
|
||||
"""Get role by name."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.name == name).first()
|
||||
|
||||
@staticmethod
|
||||
def list_roles() -> List[Role]:
|
||||
"""List all roles."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).order_by(Role.rank).all()
|
||||
@@ -1,350 +0,0 @@
|
||||
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy import func, select
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user import User
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_info_service import (
|
||||
AppConversationInfoService,
|
||||
AppConversationInfoServiceInjector,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
AppConversationInfoPage,
|
||||
AppConversationSortOrder,
|
||||
)
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
|
||||
|
||||
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
|
||||
|
||||
async def _secure_select(self):
|
||||
query = (
|
||||
select(StoredConversationMetadata)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
|
||||
async def _secure_select_with_saas_metadata(self):
|
||||
"""Select query that includes SAAS metadata for retrieving user_id."""
|
||||
query = (
|
||||
select(StoredConversationMetadata, StoredConversationMetadataSaas)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
|
||||
async def search_app_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
include_sub_conversations: bool = False,
|
||||
) -> AppConversationInfoPage:
|
||||
"""Search for conversations with user_id from SAAS metadata."""
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
|
||||
# Conditionally exclude sub-conversations based on the parameter
|
||||
if not include_sub_conversations:
|
||||
# Exclude sub-conversations (only include top-level conversations)
|
||||
query = query.where(
|
||||
StoredConversationMetadata.parent_conversation_id.is_(None)
|
||||
)
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
# Add sort order
|
||||
if sort_order == AppConversationSortOrder.CREATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.created_at)
|
||||
elif sort_order == AppConversationSortOrder.CREATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.created_at.desc())
|
||||
elif sort_order == AppConversationSortOrder.UPDATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at)
|
||||
elif sort_order == AppConversationSortOrder.UPDATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
|
||||
elif sort_order == AppConversationSortOrder.TITLE:
|
||||
query = query.order_by(StoredConversationMetadata.title)
|
||||
elif sort_order == AppConversationSortOrder.TITLE_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.title.desc())
|
||||
|
||||
# Apply pagination
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
query = query.offset(offset)
|
||||
except ValueError:
|
||||
# If page_id is not a valid integer, start from beginning
|
||||
offset = 0
|
||||
else:
|
||||
offset = 0
|
||||
|
||||
# Apply limit and get one extra to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(rows) > limit
|
||||
if has_more:
|
||||
rows = rows[:limit]
|
||||
|
||||
items = [
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
for stored_metadata, saas_metadata in rows
|
||||
]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
return AppConversationInfoPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
async def count_app_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count conversations matching the given filters with SAAS metadata."""
|
||||
query = (
|
||||
select(func.count(StoredConversationMetadata.conversation_id))
|
||||
.select_from(
|
||||
StoredConversationMetadata.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
# Apply user filtering
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
count = result.scalar()
|
||||
return count or 0
|
||||
|
||||
def _apply_filters_with_saas_metadata(
|
||||
self,
|
||||
query,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
):
|
||||
"""Apply filters to query that includes SAAS metadata."""
|
||||
# Apply the same filters as the base class
|
||||
conditions = []
|
||||
if title__contains is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.title.like(f'%{title__contains}%')
|
||||
)
|
||||
|
||||
if created_at__gte is not None:
|
||||
conditions.append(StoredConversationMetadata.created_at >= created_at__gte)
|
||||
|
||||
if created_at__lt is not None:
|
||||
conditions.append(StoredConversationMetadata.created_at < created_at__lt)
|
||||
|
||||
if updated_at__gte is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.last_updated_at >= updated_at__gte
|
||||
)
|
||||
|
||||
if updated_at__lt is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.last_updated_at < updated_at__lt
|
||||
)
|
||||
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
return query
|
||||
|
||||
async def get_app_conversation_info(
|
||||
self, conversation_id: UUID
|
||||
) -> AppConversationInfo | None:
|
||||
"""Get conversation info with user_id from SAAS metadata."""
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
query = query.where(
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
result_set = await self.db_session.execute(query)
|
||||
result = result_set.first()
|
||||
if result:
|
||||
stored_metadata, saas_metadata = result
|
||||
return self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
return None
|
||||
|
||||
async def batch_get_app_conversation_info(
|
||||
self, conversation_ids: list[UUID]
|
||||
) -> list[AppConversationInfo | None]:
|
||||
"""Batch get conversation info with user_id from SAAS metadata."""
|
||||
conversation_id_strs = [
|
||||
str(conversation_id) for conversation_id in conversation_ids
|
||||
]
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
query = query.where(
|
||||
StoredConversationMetadata.conversation_id.in_(conversation_id_strs)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Create a mapping of conversation_id to (metadata, saas_metadata)
|
||||
info_by_id = {}
|
||||
for stored_metadata, saas_metadata in rows:
|
||||
info_by_id[stored_metadata.conversation_id] = (
|
||||
stored_metadata,
|
||||
saas_metadata,
|
||||
)
|
||||
|
||||
results: list[AppConversationInfo | None] = []
|
||||
for conversation_id in conversation_id_strs:
|
||||
if conversation_id in info_by_id:
|
||||
stored_metadata, saas_metadata = info_by_id[conversation_id]
|
||||
results.append(
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
)
|
||||
else:
|
||||
results.append(None)
|
||||
|
||||
return results
|
||||
|
||||
async def save_app_conversation_info(
|
||||
self, info: AppConversationInfo
|
||||
) -> AppConversationInfo:
|
||||
"""Save conversation info and create/update SAAS metadata with user_id and org_id."""
|
||||
# Save the base conversation metadata
|
||||
await super().save_app_conversation_info(info)
|
||||
|
||||
# Get current user_id for SAAS metadata
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
# Convert string user_id to UUID
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
user_query = select(User).where(User.id == user_id_uuid)
|
||||
result = await self.db_session.execute(user_query)
|
||||
user = result.scalar_one_or_none()
|
||||
assert user
|
||||
|
||||
# Check if SAAS metadata already exists
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(info.id)
|
||||
)
|
||||
result = await self.db_session.execute(saas_query)
|
||||
existing_saas_metadata = result.scalar_one_or_none()
|
||||
assert existing_saas_metadata is None or (
|
||||
existing_saas_metadata.user_id == user_id_uuid
|
||||
and existing_saas_metadata.org_id == user.current_org_id
|
||||
)
|
||||
|
||||
if not existing_saas_metadata:
|
||||
# Create new SAAS metadata
|
||||
# Set org_id to user_id as specified in requirements
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=str(info.id),
|
||||
user_id=user_id_uuid,
|
||||
org_id=user.current_org_id,
|
||||
)
|
||||
self.db_session.add(saas_metadata)
|
||||
|
||||
await self.db_session.commit()
|
||||
|
||||
return info
|
||||
|
||||
def _to_info_with_user_id(
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
saas_metadata: StoredConversationMetadataSaas,
|
||||
) -> AppConversationInfo:
|
||||
"""Convert stored metadata to AppConversationInfo with user_id from SAAS metadata."""
|
||||
# Use the base _to_info method to get the basic info
|
||||
info = self._to_info(stored)
|
||||
|
||||
# Override the created_by_user_id with the user_id from SAAS metadata
|
||||
info.created_by_user_id = (
|
||||
str(saas_metadata.user_id) if saas_metadata.user_id else None
|
||||
)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
class SaasAppConversationInfoServiceInjector(AppConversationInfoServiceInjector):
|
||||
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[AppConversationInfoService, None]:
|
||||
from openhands.app_server.config import (
|
||||
get_db_session,
|
||||
get_user_context,
|
||||
)
|
||||
|
||||
async with (
|
||||
get_user_context(state, request) as user_context,
|
||||
get_db_session(state, request) as db_session,
|
||||
):
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=db_session, user_context=user_context
|
||||
)
|
||||
yield service
|
||||
@@ -4,15 +4,10 @@ import dataclasses
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user_store import UserStore
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.integrations.provider import ProviderType
|
||||
@@ -34,37 +29,20 @@ logger = logging.getLogger(__name__)
|
||||
class SaasConversationStore(ConversationStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
org_id: UUID | None = None # will be fetched automatically
|
||||
|
||||
def __init__(self, user_id: str, session_maker: sessionmaker):
|
||||
self.user_id = user_id
|
||||
self.session_maker = session_maker
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
self.org_id = user.current_org_id if user else None
|
||||
|
||||
def _select_by_id(self, session, conversation_id: str):
|
||||
# Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org
|
||||
query = (
|
||||
return (
|
||||
session.query(StoredConversationMetadata)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(StoredConversationMetadataSaas.user_id == UUID(self.user_id))
|
||||
.filter(StoredConversationMetadata.user_id == self.user_id)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.filter(StoredConversationMetadata.conversation_version == 'V0')
|
||||
)
|
||||
|
||||
if self.org_id is not None:
|
||||
query = query.filter(StoredConversationMetadataSaas.org_id == self.org_id)
|
||||
|
||||
return query
|
||||
|
||||
def _to_external_model(self, conversation_metadata: StoredConversationMetadata):
|
||||
kwargs = {
|
||||
c.name: getattr(conversation_metadata, c.name)
|
||||
for c in StoredConversationMetadata.__table__.columns
|
||||
if c.name != 'github_user_id' # Skip github_user_id field
|
||||
}
|
||||
# TODO: I'm not sure why the timezone is not set on the dates coming back out of the db
|
||||
kwargs['created_at'] = kwargs['created_at'].replace(tzinfo=UTC)
|
||||
@@ -75,8 +53,6 @@ class SaasConversationStore(ConversationStore):
|
||||
# Convert string to ProviderType enum
|
||||
kwargs['git_provider'] = ProviderType(kwargs['git_provider'])
|
||||
|
||||
kwargs['user_id'] = self.user_id
|
||||
|
||||
# Remove V1 attributes
|
||||
kwargs.pop('max_budget_per_task', None)
|
||||
kwargs.pop('cache_read_tokens', None)
|
||||
@@ -90,10 +66,7 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
async def save_metadata(self, metadata: ConversationMetadata):
|
||||
kwargs = dataclasses.asdict(metadata)
|
||||
|
||||
# Remove user_id and org_id from kwargs since they're no longer in StoredConversationMetadata
|
||||
kwargs.pop('user_id', None)
|
||||
kwargs.pop('org_id', None)
|
||||
kwargs['user_id'] = self.user_id
|
||||
|
||||
# Convert ProviderType enum to string for storage
|
||||
if kwargs.get('git_provider') is not None:
|
||||
@@ -107,41 +80,7 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
def _save_metadata():
|
||||
with self.session_maker() as session:
|
||||
# Save the main conversation metadata
|
||||
session.merge(stored_metadata)
|
||||
|
||||
# Create or update the SaaS metadata record
|
||||
saas_metadata = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== stored_metadata.conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not saas_metadata:
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=stored_metadata.conversation_id,
|
||||
user_id=UUID(self.user_id),
|
||||
org_id=self.org_id,
|
||||
)
|
||||
session.add(saas_metadata)
|
||||
else:
|
||||
# Validate
|
||||
expected_user_id = UUID(self.user_id)
|
||||
expected_org_id = self.org_id
|
||||
|
||||
if saas_metadata.user_id != expected_user_id:
|
||||
raise ValueError(
|
||||
f'Existing user_id ({saas_metadata.user_id}) does not match expected value ({expected_user_id}).'
|
||||
)
|
||||
|
||||
if expected_org_id and saas_metadata.org_id != expected_org_id:
|
||||
raise ValueError(
|
||||
f'Existing org_id ({saas_metadata.org_id}) does not match expected value ({expected_org_id}).'
|
||||
)
|
||||
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_save_metadata)
|
||||
@@ -161,29 +100,8 @@ class SaasConversationStore(ConversationStore):
|
||||
async def delete_metadata(self, conversation_id: str) -> None:
|
||||
def _delete_metadata():
|
||||
with self.session_maker() as session:
|
||||
saas_record = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == UUID(self.user_id),
|
||||
StoredConversationMetadataSaas.org_id == self.org_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if saas_record:
|
||||
# Delete both records, but only if the SaaS one exists
|
||||
session.query(StoredConversationMetadata).filter(
|
||||
StoredConversationMetadata.conversation_id == conversation_id,
|
||||
).delete()
|
||||
|
||||
session.delete(saas_record)
|
||||
|
||||
session.commit()
|
||||
else:
|
||||
# No SaaS record found → skip deleting main metadata
|
||||
session.rollback()
|
||||
self._select_by_id(session, conversation_id).delete()
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_delete_metadata)
|
||||
|
||||
@@ -206,15 +124,7 @@ class SaasConversationStore(ConversationStore):
|
||||
with self.session_maker() as session:
|
||||
conversations = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.user_id == UUID(self.user_id)
|
||||
)
|
||||
.filter(StoredConversationMetadataSaas.org_id == self.org_id)
|
||||
.filter(StoredConversationMetadata.user_id == self.user_id)
|
||||
.filter(StoredConversationMetadata.conversation_version == 'V0')
|
||||
.order_by(StoredConversationMetadata.created_at.desc())
|
||||
.offset(offset)
|
||||
|
||||
@@ -8,13 +8,11 @@ from cryptography.fernet import Fernet
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -26,17 +24,14 @@ class SaasSecretsStore(SecretsStore):
|
||||
async def load(self) -> Secrets | None:
|
||||
if not self.user_id:
|
||||
return None
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Fetch all secrets for the given user ID
|
||||
query = session.query(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
settings = (
|
||||
session.query(StoredCustomSecrets)
|
||||
.filter(StoredCustomSecrets.keycloak_user_id == self.user_id)
|
||||
.all()
|
||||
)
|
||||
if org_id is not None:
|
||||
query = query.filter(StoredCustomSecrets.org_id == org_id)
|
||||
settings = query.all()
|
||||
|
||||
if not settings:
|
||||
return Secrets()
|
||||
@@ -53,8 +48,6 @@ class SaasSecretsStore(SecretsStore):
|
||||
return Secrets(custom_secrets=kwargs) # type: ignore[arg-type]
|
||||
|
||||
async def store(self, item: Secrets):
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
# Incoming secrets are always the most updated ones
|
||||
# Delete all existing records and override with incoming ones
|
||||
@@ -83,7 +76,6 @@ class SaasSecretsStore(SecretsStore):
|
||||
for secret_name, secret_value, description in secret_tuples:
|
||||
new_secret = StoredCustomSecrets(
|
||||
keycloak_user_id=self.user_id,
|
||||
org_id=org_id,
|
||||
secret_name=secret_name,
|
||||
secret_value=secret_value,
|
||||
description=description,
|
||||
|
||||
@@ -2,37 +2,45 @@ from __future__ import annotations
|
||||
|
||||
import binascii
|
||||
import hashlib
|
||||
import uuid
|
||||
import json
|
||||
import os
|
||||
from base64 import b64decode, b64encode
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
from cryptography.fernet import Fernet
|
||||
from integrations import stripe_service
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
CURRENT_USER_SETTINGS_VERSION,
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
REQUIRE_PAYMENT,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import joinedload, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_store import OrgStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.settings.settings_store import SettingsStore as OssSettingsStore
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
|
||||
@dataclass
|
||||
class SaasSettingsStore(OssSettingsStore):
|
||||
class SaasSettingsStore(SettingsStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
ENCRYPT_VALUES = ['llm_api_key', 'llm_api_key_for_byor', 'search_api_key']
|
||||
|
||||
def _get_user_settings_by_keycloak_id(
|
||||
def get_user_settings_by_keycloak_id(
|
||||
self, keycloak_user_id: str, session=None
|
||||
) -> UserSettings | None:
|
||||
"""
|
||||
@@ -68,104 +76,246 @@ class SaasSettingsStore(OssSettingsStore):
|
||||
return _get_settings()
|
||||
|
||||
async def load(self) -> Settings | None:
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
if not user:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
if not self.user_id:
|
||||
return None
|
||||
with self.session_maker() as session:
|
||||
settings = self.get_user_settings_by_keycloak_id(self.user_id, session)
|
||||
|
||||
org_id = user.current_org_id
|
||||
org_member: OrgMember = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
)
|
||||
return None
|
||||
kwargs = {
|
||||
**{
|
||||
normalized: getattr(org, c.name)
|
||||
for c in Org.__table__.columns
|
||||
if (
|
||||
normalized := c.name.removeprefix('_default_')
|
||||
.removeprefix('default_')
|
||||
.lstrip('_')
|
||||
if not settings or settings.user_version != CURRENT_USER_SETTINGS_VERSION:
|
||||
logger.info(
|
||||
'saas_settings_store:load:triggering_migration',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
in Settings.model_fields
|
||||
},
|
||||
**{
|
||||
normalized: getattr(user, c.name)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) in Settings.model_fields
|
||||
},
|
||||
}
|
||||
kwargs['llm_api_key'] = org_member.llm_api_key
|
||||
if org_member.max_iterations:
|
||||
kwargs['max_iterations'] = org_member.max_iterations
|
||||
if org_member.llm_model:
|
||||
kwargs['llm_model'] = org_member.llm_model
|
||||
if org_member.llm_api_key_for_byor:
|
||||
kwargs['llm_api_key_for_byor'] = org_member.llm_api_key_for_byor
|
||||
if org_member.llm_base_url:
|
||||
kwargs['llm_base_url'] = org_member.llm_base_url
|
||||
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
return await self.create_default_settings(settings)
|
||||
kwargs = {
|
||||
c.name: getattr(settings, c.name)
|
||||
for c in UserSettings.__table__.columns
|
||||
if c.name in Settings.model_fields
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
|
||||
async def store(self, item: Settings):
|
||||
# Call the static store method from SettingsStore
|
||||
with self.session_maker() as session:
|
||||
if not item:
|
||||
return None
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(self.user_id))
|
||||
).first()
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if item and self._is_openhands_provider(item):
|
||||
await self._ensure_openhands_api_key(item)
|
||||
|
||||
with self.session_maker() as session:
|
||||
existing = None
|
||||
kwargs = {}
|
||||
if item:
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
self._encrypt_kwargs(kwargs)
|
||||
# First check if we have an existing entry in the new table
|
||||
existing = self.get_user_settings_by_keycloak_id(self.user_id, session)
|
||||
|
||||
kwargs = {
|
||||
key: value
|
||||
for key, value in kwargs.items()
|
||||
if key in UserSettings.__table__.columns
|
||||
}
|
||||
if existing:
|
||||
# Update existing entry
|
||||
for key, value in kwargs.items():
|
||||
setattr(existing, key, value)
|
||||
existing.user_version = CURRENT_USER_SETTINGS_VERSION
|
||||
session.merge(existing)
|
||||
else:
|
||||
kwargs['keycloak_user_id'] = self.user_id
|
||||
kwargs['user_version'] = CURRENT_USER_SETTINGS_VERSION
|
||||
kwargs.pop('secrets_store', None) # Don't save secrets_store to db
|
||||
settings = UserSettings(**kwargs)
|
||||
session.add(settings)
|
||||
session.commit()
|
||||
|
||||
async def create_default_settings(self, user_settings: UserSettings | None):
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:start',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
# You must log in before you get default settings
|
||||
if not self.user_id:
|
||||
return None
|
||||
|
||||
# Only users that have specified a payment method get default settings
|
||||
if REQUIRE_PAYMENT and not await stripe_service.has_payment_method(
|
||||
self.user_id
|
||||
):
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:no_payment',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
settings: Settings | None = None
|
||||
if user_settings is None:
|
||||
settings = Settings(
|
||||
language='en',
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
elif isinstance(user_settings, UserSettings):
|
||||
# Convert UserSettings (SQLAlchemy model) to Settings (Pydantic model)
|
||||
kwargs = {
|
||||
c.name: getattr(user_settings, c.name)
|
||||
for c in UserSettings.__table__.columns
|
||||
if c.name in Settings.model_fields
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
|
||||
if settings:
|
||||
settings = await self.update_settings_with_litellm_default(settings)
|
||||
if settings is None:
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:litellm_update_failed',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
await self.store(settings)
|
||||
return settings
|
||||
|
||||
async def load_legacy_file_store_settings(self, github_user_id: str):
|
||||
if not github_user_id:
|
||||
return None
|
||||
|
||||
file_store = get_file_store(self.config.file_store, self.config.file_store_path)
|
||||
path = f'users/github/{github_user_id}/settings.json'
|
||||
|
||||
try:
|
||||
json_str = await call_sync_from_async(file_store.read, path)
|
||||
logger.info(
|
||||
'saas_settings_store:load_legacy_file_store_settings:found',
|
||||
extra={'github_user_id': github_user_id},
|
||||
)
|
||||
kwargs = json.loads(json_str)
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'saas_settings_store:load_legacy_file_store_settings:error',
|
||||
extra={'github_user_id': github_user_id, 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
async def update_settings_with_litellm_default(
|
||||
self, settings: Settings
|
||||
) -> Settings | None:
|
||||
logger.info(
|
||||
'saas_settings_store:update_settings_with_litellm_default:start',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
key = LITE_LLM_API_KEY
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
token_manager = TokenManager()
|
||||
keycloak_user_info = (
|
||||
await token_manager.get_user_info_from_user_id(self.user_id) or {}
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
# Get the previous max budget to prevent accidental loss
|
||||
# In Litellm a get always succeeds, regardless of whether the user actually exists
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={self.user_id}'
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
user_info = response_json.get('user_info') or {}
|
||||
logger.info(
|
||||
f'creating_litellm_user: {self.user_id}; prev_max_budget: {user_info.get("max_budget")}; prev_metadata: {user_info.get("metadata")}'
|
||||
)
|
||||
max_budget = user_info.get('max_budget') or DEFAULT_INITIAL_BUDGET
|
||||
spend = user_info.get('spend') or 0
|
||||
|
||||
if not user:
|
||||
# Check if we need to migrate from user_settings
|
||||
user_settings = None
|
||||
with session_maker() as session:
|
||||
user_settings = self._get_user_settings_by_keycloak_id(
|
||||
user_settings = self.get_user_settings_by_keycloak_id(
|
||||
self.user_id, session
|
||||
)
|
||||
if user_settings:
|
||||
user = await UserStore.migrate_user(self.user_id, user_settings)
|
||||
else:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
# In upgrade to V4, we no longer use billing margin, but instead apply this directly
|
||||
# in litellm. The default billing marign was 2 before this (hence the magic numbers below)
|
||||
if (
|
||||
user_settings
|
||||
and user_settings.user_version < 4
|
||||
and user_settings.billing_margin
|
||||
and user_settings.billing_margin != 1.0
|
||||
):
|
||||
billing_margin = user_settings.billing_margin
|
||||
logger.info(
|
||||
'user_settings_v4_budget_upgrade',
|
||||
extra={
|
||||
'max_budget': max_budget,
|
||||
'billing_margin': billing_margin,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
max_budget *= billing_margin
|
||||
spend *= billing_margin
|
||||
user_settings.billing_margin = 1.0
|
||||
session.commit()
|
||||
|
||||
email = keycloak_user_info.get('email')
|
||||
|
||||
# We explicitly delete here to guard against odd inherited settings on upgrade.
|
||||
# We don't care if this fails with a 404
|
||||
await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [self.user_id]}
|
||||
)
|
||||
|
||||
# Create the new litellm user
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, email, max_budget, spend
|
||||
)
|
||||
if not response.is_success:
|
||||
logger.warning(
|
||||
'duplicate_user_email',
|
||||
extra={'user_id': self.user_id, 'email': email},
|
||||
)
|
||||
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, None, max_budget, spend
|
||||
)
|
||||
|
||||
# User failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [self.user_id],
|
||||
'email': email,
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
org_id = user.current_org_id
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if self._is_openhands_provider(item):
|
||||
await self._ensure_openhands_api_key(item, str(org_id))
|
||||
org_member = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
response_json = response.json()
|
||||
key = response_json['key']
|
||||
|
||||
logger.info(
|
||||
'saas_settings_store:update_settings_with_litellm_default:user_created',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
for model in (user, org, org_member):
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(model, key):
|
||||
setattr(model, key, value)
|
||||
|
||||
session.commit()
|
||||
settings.agent = 'CodeActAgent'
|
||||
# Use the model corresponding to the current user settings version
|
||||
settings.llm_model = get_default_litellm_model()
|
||||
settings.llm_api_key = SecretStr(key)
|
||||
settings.llm_base_url = LITE_LLM_API_URL
|
||||
return settings
|
||||
|
||||
@classmethod
|
||||
async def get_instance(
|
||||
@@ -176,9 +326,6 @@ class SaasSettingsStore(OssSettingsStore):
|
||||
logger.debug(f'saas_settings_store.get_instance::{user_id}')
|
||||
return SaasSettingsStore(user_id, session_maker, config)
|
||||
|
||||
def _should_encrypt(self, key):
|
||||
return key in self.ENCRYPT_VALUES
|
||||
|
||||
def _decrypt_kwargs(self, kwargs: dict):
|
||||
fernet = self._fernet()
|
||||
for key, value in kwargs.items():
|
||||
@@ -222,24 +369,21 @@ class SaasSettingsStore(OssSettingsStore):
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
return Fernet(fernet_key)
|
||||
|
||||
def _should_encrypt(self, key: str) -> bool:
|
||||
return key in ('llm_api_key', 'llm_api_key_for_byor', 'search_api_key')
|
||||
|
||||
def _is_openhands_provider(self, item: Settings) -> bool:
|
||||
"""Check if the settings use the OpenHands provider."""
|
||||
return bool(item.llm_model and item.llm_model.startswith('openhands/'))
|
||||
|
||||
async def _ensure_openhands_api_key(self, item: Settings, org_id: str) -> None:
|
||||
async def _ensure_openhands_api_key(self, item: Settings) -> None:
|
||||
"""Generate and set the OpenHands API key for the given settings.
|
||||
|
||||
First checks if an existing key with the OpenHands alias exists,
|
||||
and reuses it if found. Otherwise, generates a new key.
|
||||
"""
|
||||
# Generate new key if none exists
|
||||
generated_key = await LiteLlmManager.generate_key(
|
||||
self.user_id,
|
||||
org_id,
|
||||
f'Openhands Provider Key - user {self.user_id}',
|
||||
{'type': 'openhands'},
|
||||
)
|
||||
|
||||
generated_key = await self._generate_openhands_key()
|
||||
if generated_key:
|
||||
item.llm_api_key = SecretStr(generated_key)
|
||||
logger.info(
|
||||
@@ -251,3 +395,78 @@ class SaasSettingsStore(OssSettingsStore):
|
||||
'saas_settings_store:store:failed_to_generate_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
|
||||
async def _create_user_in_lite_llm(
|
||||
self, client: httpx.AsyncClient, email: str | None, max_budget: int, spend: int
|
||||
):
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': email,
|
||||
'models': [],
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
'user_id': str(self.user_id),
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': True,
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': CURRENT_USER_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
'key_alias': f'OpenHands Cloud - user {self.user_id}',
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
||||
async def _generate_openhands_key(self) -> str | None:
|
||||
"""Generate a new OpenHands provider key for a user."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'saas_settings_store:_generate_openhands_key:litellm_config_not_found',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json={
|
||||
'user_id': self.user_id,
|
||||
'metadata': {'type': 'openhands'},
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json.get('key')
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'saas_settings_store:_generate_openhands_key:success',
|
||||
extra={
|
||||
'user_id': self.user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': (
|
||||
key[:10] + '...' if key and len(key) > 10 else key
|
||||
),
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'saas_settings_store:_generate_openhands_key:no_key_in_response',
|
||||
extra={'user_id': self.user_id, 'response_json': response_json},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'saas_settings_store:_generate_openhands_key:error',
|
||||
extra={'user_id': self.user_id, 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from sqlalchemy import Column, ForeignKey, Identity, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -10,8 +8,4 @@ class SlackConversation(Base): # type: ignore
|
||||
conversation_id = Column(String, nullable=False, index=True)
|
||||
channel_id = Column(String, nullable=False)
|
||||
keycloak_user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
parent_id = Column(String, nullable=True, index=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='slack_conversations')
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Identity, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, DateTime, Identity, Integer, String, text
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -8,7 +6,6 @@ class SlackUser(Base): # type: ignore
|
||||
__tablename__ = 'slack_users'
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
keycloak_user_id = Column(String, nullable=False, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
slack_user_id = Column(String, nullable=False, index=True)
|
||||
slack_display_name = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
@@ -16,6 +13,3 @@ class SlackUser(Base): # type: ignore
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='slack_users')
|
||||
|
||||
@@ -1,22 +1,8 @@
|
||||
def _get_stored_conversation_metadata():
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata as _StoredConversationMetadata,
|
||||
)
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata as _StoredConversationMetadata,
|
||||
)
|
||||
|
||||
return _StoredConversationMetadata
|
||||
|
||||
|
||||
# Lazy import to avoid circular dependency
|
||||
StoredConversationMetadata = None
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
global StoredConversationMetadata
|
||||
if name == 'StoredConversationMetadata':
|
||||
if StoredConversationMetadata is None:
|
||||
StoredConversationMetadata = _get_stored_conversation_metadata()
|
||||
return StoredConversationMetadata
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
StoredConversationMetadata = _StoredConversationMetadata
|
||||
|
||||
|
||||
__all__ = ['StoredConversationMetadata']
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""
|
||||
SQLAlchemy model for ConversationMetadataSaas.
|
||||
|
||||
This model stores the SaaS-specific metadata for conversations,
|
||||
containing only the conversation_id, user_id, and org_id.
|
||||
"""
|
||||
|
||||
from sqlalchemy import UUID as SQL_UUID
|
||||
from sqlalchemy import Column, ForeignKey, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class StoredConversationMetadataSaas(Base): # type: ignore
|
||||
"""SaaS conversation metadata model containing user and org associations."""
|
||||
|
||||
__tablename__ = 'conversation_metadata_saas'
|
||||
|
||||
conversation_id = Column(String, primary_key=True)
|
||||
user_id = Column(SQL_UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
|
||||
org_id = Column(SQL_UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
|
||||
|
||||
# Relationships
|
||||
user = relationship('User', back_populates='stored_conversation_metadata_saas')
|
||||
org = relationship('Org', back_populates='stored_conversation_metadata_saas')
|
||||
|
||||
|
||||
__all__ = ['StoredConversationMetadataSaas']
|
||||
@@ -1,6 +1,4 @@
|
||||
from sqlalchemy import Column, ForeignKey, Identity, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -8,10 +6,6 @@ class StoredCustomSecrets(Base): # type: ignore
|
||||
__tablename__ = 'custom_secrets'
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
keycloak_user_id = Column(String, nullable=True, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
secret_name = Column(String, nullable=False)
|
||||
secret_value = Column(String, nullable=False)
|
||||
description = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='user_secrets')
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, DateTime, Integer, String, text
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -15,7 +13,6 @@ class StripeCustomer(Base): # type: ignore
|
||||
__tablename__ = 'stripe_customers'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
keycloak_user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
stripe_customer_id = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
|
||||
@@ -26,6 +23,3 @@ class StripeCustomer(Base): # type: ignore
|
||||
onupdate=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='stripe_customers')
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
"""
|
||||
SQLAlchemy model for User.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import (
|
||||
UUID,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class User(Base): # type: ignore
|
||||
"""User model with organizational relationships."""
|
||||
|
||||
__tablename__ = 'user'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
current_org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=True)
|
||||
accepted_tos = Column(DateTime, nullable=True)
|
||||
enable_sound_notifications = Column(Boolean, nullable=True)
|
||||
language = Column(String, nullable=True)
|
||||
user_consents_to_analytics = Column(Boolean, nullable=True)
|
||||
email = Column(String, nullable=True)
|
||||
email_verified = Column(Boolean, nullable=True)
|
||||
|
||||
# Relationships
|
||||
role = relationship('Role', back_populates='users')
|
||||
org_members = relationship('OrgMember', back_populates='user')
|
||||
current_org = relationship('Org', back_populates='current_users')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='user'
|
||||
)
|
||||
@@ -39,6 +39,3 @@ class UserSettings(Base): # type: ignore
|
||||
git_user_name = Column(String, nullable=True)
|
||||
git_user_email = Column(String, nullable=True)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
already_migrated = Column(
|
||||
Boolean, nullable=True, default=False
|
||||
) # False = not migrated, True = migrated
|
||||
|
||||
@@ -1,332 +0,0 @@
|
||||
"""
|
||||
Store class for managing users.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from server.logger import logger
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.encrypt_utils import decrypt_legacy_model
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
|
||||
|
||||
|
||||
class UserStore:
|
||||
"""Store for managing users."""
|
||||
|
||||
@staticmethod
|
||||
async def create_user(
|
||||
user_id: str,
|
||||
user_info: dict,
|
||||
role_id: Optional[int] = None,
|
||||
) -> User | None:
|
||||
"""Create a new user."""
|
||||
with session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_name=user_info['preferred_username'],
|
||||
contact_email=user_info['email'],
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
settings = await UserStore.create_default_settings(
|
||||
org_id=str(org.id), user_id=user_id
|
||||
)
|
||||
|
||||
if not settings:
|
||||
return None
|
||||
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
org_kwargs = OrgStore.get_kwargs_from_settings(settings)
|
||||
for key, value in org_kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
user_kwargs = UserStore.get_kwargs_from_settings(settings)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
role_id=role_id,
|
||||
**user_kwargs,
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
role = RoleStore.get_role_by_name('owner')
|
||||
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
org_member_kwargs = OrgMemberStore.get_kwargs_from_settings(settings)
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id, # owner of your own org.
|
||||
status='active',
|
||||
**org_member_kwargs,
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def migrate_user(
|
||||
user_id: str,
|
||||
user_settings: UserSettings,
|
||||
user_info: dict,
|
||||
) -> User:
|
||||
if not user_id or not user_settings:
|
||||
return None
|
||||
|
||||
kwargs = decrypt_legacy_model(
|
||||
[
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
],
|
||||
user_settings,
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
with session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_name=user_info['username'],
|
||||
contact_email=user_info['email'],
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
await LiteLlmManager.migrate_entries(
|
||||
str(org.id),
|
||||
user_id,
|
||||
decrypted_user_settings,
|
||||
)
|
||||
|
||||
# avoids circular reference. This migrate method is temprorary until all users are migrated.
|
||||
from integrations.stripe_service import migrate_customer
|
||||
|
||||
await migrate_customer(session, user_id, org)
|
||||
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
org_kwargs = OrgStore.get_kwargs_from_user_settings(decrypted_user_settings)
|
||||
org_kwargs.pop('id', None)
|
||||
for key, value in org_kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
user_kwargs = UserStore.get_kwargs_from_user_settings(
|
||||
decrypted_user_settings
|
||||
)
|
||||
user_kwargs.pop('id', None)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
role_id=None,
|
||||
**user_kwargs,
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
role = RoleStore.get_role_by_name('owner')
|
||||
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
org_member_kwargs = OrgMemberStore.get_kwargs_from_user_settings(
|
||||
decrypted_user_settings
|
||||
)
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id, # owner of your own org.
|
||||
status='active',
|
||||
**org_member_kwargs,
|
||||
)
|
||||
session.add(org_member)
|
||||
|
||||
# Mark the old user_settings as migrated instead of deleting
|
||||
user_settings.already_migrated = True
|
||||
session.merge(user_settings)
|
||||
session.flush()
|
||||
|
||||
# need to migrate conversation metadata
|
||||
session.execute(
|
||||
text("""
|
||||
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
|
||||
SELECT
|
||||
conversation_id,
|
||||
:user_id,
|
||||
:user_id
|
||||
FROM conversation_metadata
|
||||
WHERE user_id = :user_id
|
||||
"""),
|
||||
{'user_id': user_id},
|
||||
)
|
||||
|
||||
# Update org_id for tables that had org_id added
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
|
||||
# Update stripe_customers
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE stripe_customers SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update slack_users
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE slack_users SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update slack_conversation
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE slack_conversation SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update api_keys
|
||||
session.execute(
|
||||
text('UPDATE api_keys SET org_id = :org_id WHERE user_id = :user_id'),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update custom_secrets
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE custom_secrets SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update billing_sessions
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE billing_sessions SET org_id = :org_id WHERE user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID."""
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
# Check if we need to migrate from user_settings
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user_settings:
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
token_manager = TokenManager()
|
||||
user_info = call_async_from_sync(
|
||||
token_manager.get_user_info_from_user_id,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
)
|
||||
user = call_async_from_sync(
|
||||
UserStore.migrate_user,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
user_settings,
|
||||
user_info,
|
||||
)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def list_users() -> list[User]:
|
||||
"""List all users."""
|
||||
with session_maker() as session:
|
||||
return session.query(User).all()
|
||||
|
||||
# Prevent circular imports
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
@staticmethod
|
||||
async def create_default_settings(
|
||||
org_id: str, user_id: str
|
||||
) -> Optional['Settings']:
|
||||
logger.info(
|
||||
'UserStore:create_default_settings:start',
|
||||
extra={'org_id': org_id, 'user_id': user_id},
|
||||
)
|
||||
# You must log in before you get default settings
|
||||
if not org_id:
|
||||
return None
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
settings = Settings(language='en', enable_proactive_conversation_starters=True)
|
||||
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
settings = await LiteLlmManager.create_entries(org_id, user_id, settings)
|
||||
if not settings:
|
||||
logger.info(
|
||||
'UserStore:create_default_settings:litellm_create_failed',
|
||||
extra={'org_id': org_id},
|
||||
)
|
||||
return None
|
||||
|
||||
return settings
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: 'Settings'):
|
||||
kwargs = {
|
||||
normalized: getattr(settings, normalized)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {
|
||||
normalized: getattr(user_settings, normalized)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
@@ -3,7 +3,7 @@ from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from integrations.types import GitLabResourceType
|
||||
from server.constants import WEB_HOST
|
||||
from integrations.utils import GITLAB_WEBHOOK_URL
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
from storage.gitlab_webhook_store import GitlabWebhookStore
|
||||
|
||||
@@ -11,7 +11,6 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.service_types import GitService
|
||||
|
||||
GITLAB_WEBHOOK_URL = f'https://{WEB_HOST}/integration/gitlab/events'
|
||||
CHUNK_SIZE = 100
|
||||
WEBHOOK_NAME = 'OpenHands Resolver'
|
||||
SCOPES: list[str] = [
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from server.constants import ORG_SETTINGS_VERSION
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
from server.maintenance_task_processor.user_version_upgrade_processor import (
|
||||
UserVersionUpgradeProcessor,
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
@@ -13,20 +14,11 @@ from storage.billing_session import BillingSession
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.feedback import Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.user import User
|
||||
|
||||
# Import the actual StoredConversationMetadata from OpenHands core
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -75,6 +67,7 @@ def add_minimal_fixtures(session_maker):
|
||||
session.add(
|
||||
StoredConversationMetadata(
|
||||
conversation_id='mock-conversation-id',
|
||||
user_id='mock-user-id',
|
||||
created_at=datetime.fromisoformat('2025-03-07'),
|
||||
last_updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
accumulated_cost=5.25,
|
||||
@@ -83,13 +76,6 @@ def add_minimal_fixtures(session_maker):
|
||||
total_tokens=750,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StoredConversationMetadataSaas(
|
||||
conversation_id='mock-conversation-id',
|
||||
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StoredOfflineToken(
|
||||
user_id='mock-user-id',
|
||||
@@ -98,38 +84,7 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
Org(
|
||||
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
name='mock-org',
|
||||
org_version=ORG_SETTINGS_VERSION,
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
Role(
|
||||
id=1,
|
||||
name='admin',
|
||||
rank=1,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
User(
|
||||
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
current_org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
user_consents_to_analytics=True,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
OrgMember(
|
||||
org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
user_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
role_id=1,
|
||||
llm_api_key='mock-api-key',
|
||||
status='active',
|
||||
)
|
||||
)
|
||||
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id='mock-user-id',
|
||||
@@ -138,6 +93,13 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-10'),
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
UserSettings(
|
||||
keycloak_user_id='mock-user-id',
|
||||
user_consents_to_analytics=True,
|
||||
user_version=CURRENT_USER_SETTINGS_VERSION,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
ConversationWork(
|
||||
conversation_id='mock-conversation-id',
|
||||
@@ -146,6 +108,17 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
maintenance_task = MaintenanceTask(
|
||||
status=MaintenanceTaskStatus.PENDING,
|
||||
)
|
||||
maintenance_task.set_processor(
|
||||
UserVersionUpgradeProcessor(
|
||||
user_ids=['mock-user-id'],
|
||||
created_at=datetime.fromisoformat('2025-03-07'),
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
session.add(maintenance_task)
|
||||
session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
"""Test for ResolverUserContext get_secrets conversion logic.
|
||||
|
||||
This test focuses on testing the actual ResolverUserContext implementation.
|
||||
"""
|
||||
|
||||
from types import MappingProxyType
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from enterprise.integrations.resolver_context import ResolverUserContext
|
||||
|
||||
# Import the real classes we want to test
|
||||
from openhands.integrations.provider import CustomSecret
|
||||
|
||||
# Import the SDK types we need for testing
|
||||
from openhands.sdk.conversation.secret_source import SecretSource, StaticSecret
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_saas_user_auth():
|
||||
"""Mock SaasUserAuth for testing."""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def resolver_context(mock_saas_user_auth):
|
||||
"""Create a ResolverUserContext instance for testing."""
|
||||
return ResolverUserContext(saas_user_auth=mock_saas_user_auth)
|
||||
|
||||
|
||||
def create_custom_secret(value: str, description: str = 'Test secret') -> CustomSecret:
|
||||
"""Helper to create CustomSecret instances."""
|
||||
return CustomSecret(secret=SecretStr(value), description=description)
|
||||
|
||||
|
||||
def create_secrets(custom_secrets_dict: dict[str, CustomSecret]) -> Secrets:
|
||||
"""Helper to create Secrets instances."""
|
||||
return Secrets(custom_secrets=MappingProxyType(custom_secrets_dict))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secrets_converts_custom_to_static(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_secrets correctly converts CustomSecret objects to StaticSecret objects."""
|
||||
# Arrange
|
||||
secrets = create_secrets(
|
||||
{
|
||||
'TEST_SECRET_1': create_custom_secret('secret_value_1'),
|
||||
'TEST_SECRET_2': create_custom_secret('secret_value_2'),
|
||||
}
|
||||
)
|
||||
mock_saas_user_auth.get_secrets.return_value = secrets
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_secrets()
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(secret, StaticSecret) for secret in result.values())
|
||||
assert result['TEST_SECRET_1'].value.get_secret_value() == 'secret_value_1'
|
||||
assert result['TEST_SECRET_2'].value.get_secret_value() == 'secret_value_2'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secrets_with_special_characters(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that secret values with special characters are preserved during conversion."""
|
||||
# Arrange
|
||||
special_value = 'very_secret_password_123!@#$%^&*()'
|
||||
secrets = create_secrets({'SPECIAL_SECRET': create_custom_secret(special_value)})
|
||||
mock_saas_user_auth.get_secrets.return_value = secrets
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_secrets()
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert isinstance(result['SPECIAL_SECRET'], StaticSecret)
|
||||
assert result['SPECIAL_SECRET'].value.get_secret_value() == special_value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'secrets_input,expected_result',
|
||||
[
|
||||
(None, {}), # No secrets available
|
||||
(create_secrets({}), {}), # Empty custom secrets
|
||||
],
|
||||
)
|
||||
async def test_get_secrets_empty_cases(
|
||||
resolver_context, mock_saas_user_auth, secrets_input, expected_result
|
||||
):
|
||||
"""Test that get_secrets handles empty cases correctly."""
|
||||
# Arrange
|
||||
mock_saas_user_auth.get_secrets.return_value = secrets_input
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_secrets()
|
||||
|
||||
# Assert
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_static_secret_is_valid_secret_source():
|
||||
"""Test that StaticSecret is a valid SecretSource for SDK validation."""
|
||||
# Arrange & Act
|
||||
static_secret = StaticSecret(value='test_secret_123')
|
||||
|
||||
# Assert
|
||||
assert isinstance(static_secret, StaticSecret)
|
||||
assert isinstance(static_secret, SecretSource)
|
||||
assert static_secret.value.get_secret_value() == 'test_secret_123'
|
||||
|
||||
|
||||
def test_custom_to_static_conversion():
|
||||
"""Test the complete conversion flow from CustomSecret to StaticSecret."""
|
||||
# Arrange
|
||||
secret_value = 'conversion_test_secret'
|
||||
custom_secret = create_custom_secret(secret_value, 'Conversion test')
|
||||
|
||||
# Act - simulate the conversion logic from the actual method
|
||||
extracted_value = custom_secret.secret.get_secret_value()
|
||||
static_secret = StaticSecret(value=extracted_value)
|
||||
|
||||
# Assert
|
||||
assert isinstance(static_secret, StaticSecret)
|
||||
assert isinstance(static_secret, SecretSource)
|
||||
assert static_secret.value.get_secret_value() == secret_value
|
||||
@@ -6,32 +6,22 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import BackgroundTasks, HTTPException, Request, status
|
||||
|
||||
# Import the actual StoredConversationMetadata from OpenHands core
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
from server.routes.event_webhook import (
|
||||
BatchMethod,
|
||||
BatchOperation,
|
||||
_get_session_api_key,
|
||||
_get_user_id,
|
||||
_parse_conversation_id_and_subpath,
|
||||
_process_batch_operations_background,
|
||||
on_batch_write,
|
||||
on_delete,
|
||||
on_write,
|
||||
)
|
||||
|
||||
# Mock the lazy import to return the actual class
|
||||
with patch(
|
||||
'storage.stored_conversation_metadata.StoredConversationMetadata',
|
||||
StoredConversationMetadata,
|
||||
):
|
||||
from server.routes.event_webhook import (
|
||||
BatchMethod,
|
||||
BatchOperation,
|
||||
_get_session_api_key,
|
||||
_get_user_id,
|
||||
_parse_conversation_id_and_subpath,
|
||||
_process_batch_operations_background,
|
||||
on_batch_write,
|
||||
on_delete,
|
||||
on_write,
|
||||
)
|
||||
from server.utils.conversation_callback_utils import (
|
||||
process_event,
|
||||
update_conversation_metadata,
|
||||
)
|
||||
from server.utils.conversation_callback_utils import (
|
||||
process_event,
|
||||
update_conversation_metadata,
|
||||
)
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
|
||||
@@ -92,7 +82,7 @@ class TestGetUserId:
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
user_id = _get_user_id('mock-conversation-id')
|
||||
assert user_id == '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
assert user_id == 'mock-user-id'
|
||||
|
||||
def test_get_user_id_conversation_not_found(self, session_maker):
|
||||
"""Test getting user ID when conversation doesn't exist."""
|
||||
@@ -115,12 +105,10 @@ class TestGetSessionApiKey:
|
||||
return_value=[mock_agent_loop_info]
|
||||
)
|
||||
|
||||
api_key = await _get_session_api_key(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', 'conv-456'
|
||||
)
|
||||
api_key = await _get_session_api_key('user-123', 'conv-456')
|
||||
assert api_key == 'test-api-key'
|
||||
mock_manager.get_agent_loop_info.assert_called_once_with(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', filter_to_sids={'conv-456'}
|
||||
'user-123', filter_to_sids={'conv-456'}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -130,9 +118,7 @@ class TestGetSessionApiKey:
|
||||
mock_manager.get_agent_loop_info = AsyncMock(return_value=[])
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
await _get_session_api_key(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', 'conv-456'
|
||||
)
|
||||
await _get_session_api_key('user-123', 'conv-456')
|
||||
|
||||
|
||||
class TestProcessEvent:
|
||||
@@ -156,15 +142,10 @@ class TestProcessEvent:
|
||||
mock_event = MagicMock()
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
|
||||
mock_file_store.write.assert_called_once_with(
|
||||
'users/5594c7b6-f959-4b81-92e9-b09c206f5081/conversations/conv-456/events/event-1.json',
|
||||
'users/user-123/conversations/conv-456/events/event-1.json',
|
||||
json.dumps(content),
|
||||
)
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
@@ -196,19 +177,14 @@ class TestProcessEvent:
|
||||
)
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
|
||||
mock_file_store.write.assert_called_once()
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
mock_invoke_callbacks.assert_called_once_with('conv-456', mock_event)
|
||||
mock_update_working_seconds.assert_called_once()
|
||||
mock_event_store_class.assert_called_once_with(
|
||||
'conv-456', mock_file_store, '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
'conv-456', mock_file_store, 'user-123'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -236,12 +212,7 @@ class TestProcessEvent:
|
||||
mock_event.agent_state = 'running' # Set RUNNING state to skip the update
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
|
||||
mock_file_store.write.assert_called_once()
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
@@ -265,13 +236,10 @@ class TestUpdateConversationMetadata:
|
||||
'total_tokens': 1500,
|
||||
}
|
||||
|
||||
# Import the module and patch the session_maker at the module level
|
||||
import server.utils.conversation_callback_utils as callback_utils
|
||||
|
||||
original_session_maker = callback_utils.session_maker
|
||||
|
||||
try:
|
||||
callback_utils.session_maker = session_maker_with_minimal_fixtures
|
||||
with patch(
|
||||
'server.utils.conversation_callback_utils.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
update_conversation_metadata('mock-conversation-id', content)
|
||||
|
||||
# Verify the conversation was updated
|
||||
@@ -289,9 +257,6 @@ class TestUpdateConversationMetadata:
|
||||
assert conversation.completion_tokens == 500
|
||||
assert conversation.total_tokens == 1500
|
||||
assert isinstance(conversation.last_updated_at, datetime)
|
||||
finally:
|
||||
# Restore the original session_maker
|
||||
callback_utils.session_maker = original_session_maker
|
||||
|
||||
def test_update_conversation_metadata_partial_fields(
|
||||
self, session_maker_with_minimal_fixtures
|
||||
@@ -299,13 +264,10 @@ class TestUpdateConversationMetadata:
|
||||
"""Test updating conversation metadata with only some fields."""
|
||||
content = {'accumulated_cost': 15.75, 'prompt_tokens': 2000}
|
||||
|
||||
# Import the module and patch the session_maker at the module level
|
||||
import server.utils.conversation_callback_utils as callback_utils
|
||||
|
||||
original_session_maker = callback_utils.session_maker
|
||||
|
||||
try:
|
||||
callback_utils.session_maker = session_maker_with_minimal_fixtures
|
||||
with patch(
|
||||
'server.utils.conversation_callback_utils.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
update_conversation_metadata('mock-conversation-id', content)
|
||||
|
||||
# Verify only specified fields were updated, others remain unchanged
|
||||
@@ -323,9 +285,6 @@ class TestUpdateConversationMetadata:
|
||||
# These should remain as original values from fixtures
|
||||
assert conversation.completion_tokens == 250
|
||||
assert conversation.total_tokens == 750
|
||||
finally:
|
||||
# Restore the original session_maker
|
||||
callback_utils.session_maker = original_session_maker
|
||||
|
||||
def test_update_conversation_metadata_empty_content(
|
||||
self, session_maker_with_minimal_fixtures
|
||||
@@ -333,13 +292,10 @@ class TestUpdateConversationMetadata:
|
||||
"""Test updating conversation metadata with empty content."""
|
||||
content: dict[str, float] = {}
|
||||
|
||||
# Import the module and patch the session_maker at the module level
|
||||
import server.utils.conversation_callback_utils as callback_utils
|
||||
|
||||
original_session_maker = callback_utils.session_maker
|
||||
|
||||
try:
|
||||
callback_utils.session_maker = session_maker_with_minimal_fixtures
|
||||
with patch(
|
||||
'server.utils.conversation_callback_utils.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
update_conversation_metadata('mock-conversation-id', content)
|
||||
|
||||
# Verify only last_updated_at was changed
|
||||
@@ -358,9 +314,6 @@ class TestUpdateConversationMetadata:
|
||||
assert conversation.completion_tokens == 250
|
||||
assert conversation.total_tokens == 750
|
||||
assert isinstance(conversation.last_updated_at, datetime)
|
||||
finally:
|
||||
# Restore the original session_maker
|
||||
callback_utils.session_maker = original_session_maker
|
||||
|
||||
|
||||
class TestOnDelete:
|
||||
@@ -391,31 +344,24 @@ class TestOnWrite:
|
||||
content = {'accumulated_cost': 20.0}
|
||||
mock_request.json.return_value = content
|
||||
|
||||
# Import the module and patch the session_maker at the module level
|
||||
import server.utils.conversation_callback_utils as callback_utils
|
||||
with patch(
|
||||
'server.routes.event_webhook.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
), patch(
|
||||
'server.utils.conversation_callback_utils.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
), patch(
|
||||
'server.routes.event_webhook._get_session_api_key'
|
||||
) as mock_get_api_key:
|
||||
mock_get_api_key.return_value = 'correct-api-key'
|
||||
|
||||
original_session_maker = callback_utils.session_maker
|
||||
result = await on_write(
|
||||
'sessions/mock-conversation-id/metadata.json',
|
||||
mock_request,
|
||||
'correct-api-key',
|
||||
)
|
||||
|
||||
try:
|
||||
with patch(
|
||||
'server.routes.event_webhook.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
), patch(
|
||||
'server.routes.event_webhook._get_session_api_key'
|
||||
) as mock_get_api_key:
|
||||
mock_get_api_key.return_value = 'correct-api-key'
|
||||
callback_utils.session_maker = session_maker_with_minimal_fixtures
|
||||
|
||||
result = await on_write(
|
||||
'sessions/mock-conversation-id/metadata.json',
|
||||
mock_request,
|
||||
'correct-api-key',
|
||||
)
|
||||
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
finally:
|
||||
# Restore the original session_maker
|
||||
callback_utils.session_maker = original_session_maker
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_write_events_success(
|
||||
@@ -623,38 +569,31 @@ class TestProcessBatchOperationsBackground:
|
||||
)
|
||||
]
|
||||
|
||||
# Import the module and patch the session_maker at the module level
|
||||
import server.utils.conversation_callback_utils as callback_utils
|
||||
with patch(
|
||||
'server.routes.event_webhook.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
), patch(
|
||||
'server.routes.event_webhook._get_session_api_key'
|
||||
) as mock_get_api_key, patch(
|
||||
'server.utils.conversation_callback_utils.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
mock_get_api_key.return_value = 'correct-api-key'
|
||||
|
||||
original_session_maker = callback_utils.session_maker
|
||||
# Should not raise any exceptions
|
||||
await _process_batch_operations_background(batch_ops, 'correct-api-key')
|
||||
|
||||
try:
|
||||
with patch(
|
||||
'server.routes.event_webhook.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
), patch(
|
||||
'server.routes.event_webhook._get_session_api_key'
|
||||
) as mock_get_api_key:
|
||||
mock_get_api_key.return_value = 'correct-api-key'
|
||||
callback_utils.session_maker = session_maker_with_minimal_fixtures
|
||||
|
||||
# Should not raise any exceptions
|
||||
await _process_batch_operations_background(batch_ops, 'correct-api-key')
|
||||
|
||||
# Verify the conversation metadata was updated
|
||||
with session_maker_with_minimal_fixtures() as session:
|
||||
conversation = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id
|
||||
== 'mock-conversation-id'
|
||||
)
|
||||
.first()
|
||||
# Verify the conversation metadata was updated
|
||||
with session_maker_with_minimal_fixtures() as session:
|
||||
conversation = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id
|
||||
== 'mock-conversation-id'
|
||||
)
|
||||
assert conversation.accumulated_cost == 15.0
|
||||
finally:
|
||||
# Restore the original session_maker
|
||||
callback_utils.session_maker = original_session_maker
|
||||
.first()
|
||||
)
|
||||
assert conversation.accumulated_cost == 15.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_batch_operations_events_success(
|
||||
@@ -705,27 +644,20 @@ class TestProcessBatchOperationsBackground:
|
||||
),
|
||||
]
|
||||
|
||||
# Import the module and patch the session_maker at the module level
|
||||
import server.utils.conversation_callback_utils as callback_utils
|
||||
with patch(
|
||||
'server.routes.event_webhook.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
), patch(
|
||||
'server.routes.event_webhook._get_session_api_key'
|
||||
) as mock_get_api_key, patch(
|
||||
'server.utils.conversation_callback_utils.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
# First call succeeds, second fails
|
||||
mock_get_api_key.side_effect = ['correct-api-key', 'wrong-api-key']
|
||||
|
||||
original_session_maker = callback_utils.session_maker
|
||||
|
||||
try:
|
||||
with patch(
|
||||
'server.routes.event_webhook.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
), patch(
|
||||
'server.routes.event_webhook._get_session_api_key'
|
||||
) as mock_get_api_key:
|
||||
# First call succeeds, second fails
|
||||
mock_get_api_key.side_effect = ['correct-api-key', 'wrong-api-key']
|
||||
callback_utils.session_maker = session_maker_with_minimal_fixtures
|
||||
|
||||
# Should not raise exceptions, just log errors
|
||||
await _process_batch_operations_background(batch_ops, 'correct-api-key')
|
||||
finally:
|
||||
# Restore the original session_maker
|
||||
callback_utils.session_maker = original_session_maker
|
||||
# Should not raise exceptions, just log errors
|
||||
await _process_batch_operations_background(batch_ops, 'correct-api-key')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_batch_operations_invalid_method_skipped(
|
||||
|
||||
@@ -1,371 +0,0 @@
|
||||
"""Tests for SaasSQLAppConversationInfoService.
|
||||
|
||||
This module tests the SAAS implementation of SQLAppConversationInfoService,
|
||||
focusing on user isolation, SAAS metadata handling, and multi-tenant functionality.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
# Import the SAAS service
|
||||
from enterprise.storage.saas_app_conversation_info_injector import (
|
||||
SaasSQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(async_session) -> SaasSQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance for testing."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service_with_user(async_session) -> SaasSQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance with a user_id for testing."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(user_id='a1111111-1111-1111-1111-111111111111'),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation_info() -> AppConversationInfo:
|
||||
"""Create a sample AppConversationInfo for testing."""
|
||||
return AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
||||
sandbox_id='sandbox_123',
|
||||
selected_repository='https://github.com/test/repo',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Test Conversation',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[123, 456],
|
||||
llm_model='gpt-4',
|
||||
metrics=None,
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multiple_conversation_infos() -> list[AppConversationInfo]:
|
||||
"""Create multiple AppConversationInfo instances for testing."""
|
||||
base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
return [
|
||||
AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=None,
|
||||
sandbox_id=f'sandbox_{i}',
|
||||
selected_repository=f'https://github.com/test/repo{i}',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title=f'Test Conversation {i}',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[i * 100],
|
||||
llm_model='gpt-4',
|
||||
metrics=None,
|
||||
created_at=base_time.replace(hour=12 + i),
|
||||
updated_at=base_time.replace(hour=12 + i, minute=30),
|
||||
)
|
||||
for i in range(1, 6) # Create 5 conversations
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user1_context():
|
||||
"""Create user context for user1."""
|
||||
return SpecifyUserContext(user_id='a1111111-1111-1111-1111-111111111111')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user2_context():
|
||||
"""Create user context for user2."""
|
||||
return SpecifyUserContext(user_id='b2222222-2222-2222-2222-222222222222')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def saas_service_user1(mock_db_session, user1_context):
|
||||
"""Create a SaasSQLAppConversationInfoService instance for user1."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=mock_db_session, user_context=user1_context
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def saas_service_user2(mock_db_session, user2_context):
|
||||
"""Create a SaasSQLAppConversationInfoService instance for user2."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=mock_db_session, user_context=user2_context
|
||||
)
|
||||
|
||||
|
||||
class TestSaasSQLAppConversationInfoService:
|
||||
"""Test suite for SaasSQLAppConversationInfoService."""
|
||||
|
||||
def test_service_initialization(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
user1_context: SpecifyUserContext,
|
||||
):
|
||||
"""Test that the SAAS service is properly initialized."""
|
||||
assert saas_service_user1.user_context == user1_context
|
||||
assert saas_service_user1.db_session is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_context_isolation(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
saas_service_user2: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that different service instances have different user contexts."""
|
||||
user1_id = await saas_service_user1.user_context.get_user_id()
|
||||
user2_id = await saas_service_user2.user_context.get_user_id()
|
||||
|
||||
assert user1_id == 'a1111111-1111-1111-1111-111111111111'
|
||||
assert user2_id == 'b2222222-2222-2222-2222-222222222222'
|
||||
assert user1_id != user2_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secure_select_includes_user_filtering(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that _secure_select method includes user filtering."""
|
||||
# This test verifies that the _secure_select method exists and can be called
|
||||
# The actual SQL generation is tested implicitly through integration
|
||||
query = await saas_service_user1._secure_select()
|
||||
assert query is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_info_with_user_id_functionality(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that _to_info_with_user_id properly sets user_id from SAAS metadata."""
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
# Create mock metadata objects
|
||||
stored_metadata = MagicMock(spec=StoredConversationMetadata)
|
||||
stored_metadata.conversation_id = '12345678-1234-5678-1234-567812345678'
|
||||
stored_metadata.parent_conversation_id = None
|
||||
stored_metadata.title = 'Test Conversation'
|
||||
stored_metadata.sandbox_id = 'test-sandbox'
|
||||
stored_metadata.selected_repository = None
|
||||
stored_metadata.selected_branch = None
|
||||
stored_metadata.git_provider = None
|
||||
stored_metadata.trigger = None
|
||||
stored_metadata.pr_number = []
|
||||
stored_metadata.llm_model = None
|
||||
from datetime import datetime, timezone
|
||||
|
||||
stored_metadata.created_at = datetime.now(timezone.utc)
|
||||
stored_metadata.last_updated_at = datetime.now(timezone.utc)
|
||||
stored_metadata.accumulated_cost = 0.0
|
||||
stored_metadata.prompt_tokens = 0
|
||||
stored_metadata.completion_tokens = 0
|
||||
stored_metadata.total_tokens = 0
|
||||
stored_metadata.max_budget_per_task = None
|
||||
stored_metadata.cache_read_tokens = 0
|
||||
stored_metadata.cache_write_tokens = 0
|
||||
stored_metadata.reasoning_tokens = 0
|
||||
stored_metadata.context_window = 0
|
||||
stored_metadata.per_turn_token = 0
|
||||
|
||||
saas_metadata = MagicMock(spec=StoredConversationMetadataSaas)
|
||||
saas_metadata.user_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
saas_metadata.org_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
|
||||
# Test the _to_info_with_user_id method
|
||||
result = saas_service_user1._to_info_with_user_id(
|
||||
stored_metadata, saas_metadata
|
||||
)
|
||||
|
||||
# Verify that the user_id from SAAS metadata is used
|
||||
assert result.created_by_user_id == 'a1111111-1111-1111-1111-111111111111'
|
||||
assert result.title == 'Test Conversation'
|
||||
assert result.sandbox_id == 'test-sandbox'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_isolation(
|
||||
self,
|
||||
async_session: AsyncSession,
|
||||
multiple_conversation_infos: list[AppConversationInfo],
|
||||
):
|
||||
"""Test that user isolation works correctly."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from storage.user import User
|
||||
|
||||
# Mock the database session execute method to return mock users
|
||||
# This mock intercepts User queries and returns a mock user object
|
||||
# with user_id and org_id the same as the user_id_uuid from the query
|
||||
original_execute = async_session.execute
|
||||
|
||||
async def mock_execute(query):
|
||||
query_str = str(query)
|
||||
|
||||
# Check if this is a User query
|
||||
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
|
||||
# Extract the UUID from the query parameters
|
||||
# The query will have bound parameters, we need to get the UUID value
|
||||
if hasattr(query, 'compile'):
|
||||
try:
|
||||
compiled = query.compile(compile_kwargs={'literal_binds': True})
|
||||
query_with_params = str(compiled)
|
||||
|
||||
# Extract UUID from the query string
|
||||
import re
|
||||
|
||||
# Try both formats: with dashes and without dashes
|
||||
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
|
||||
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
|
||||
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_with_dashes, query_with_params
|
||||
)
|
||||
if not uuid_match:
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_without_dashes, query_with_params
|
||||
)
|
||||
|
||||
if uuid_match:
|
||||
user_id_str = uuid_match.group(0)
|
||||
# If the UUID doesn't have dashes, add them
|
||||
if len(user_id_str) == 32 and '-' not in user_id_str:
|
||||
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
|
||||
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
|
||||
# Create a mock user with user_id and org_id the same as user_id_uuid
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.id = user_id_uuid
|
||||
mock_user.current_org_id = user_id_uuid
|
||||
|
||||
# Create a mock result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_user
|
||||
return mock_result
|
||||
except Exception:
|
||||
# If there's any error in parsing, fall back to original execute
|
||||
pass
|
||||
|
||||
# For all other queries, use the original execute method
|
||||
return await original_execute(query)
|
||||
|
||||
# Apply the mock
|
||||
async_session.execute = mock_execute
|
||||
|
||||
# Create services for different users
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='a1111111-1111-1111-1111-111111111111'
|
||||
),
|
||||
)
|
||||
user2_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='b2222222-2222-2222-2222-222222222222'
|
||||
),
|
||||
)
|
||||
|
||||
# Create conversations for different users
|
||||
user1_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
||||
sandbox_id='sandbox_user1',
|
||||
title='User 1 Conversation',
|
||||
)
|
||||
|
||||
user2_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='b2222222-2222-2222-2222-222222222222',
|
||||
sandbox_id='sandbox_user2',
|
||||
title='User 2 Conversation',
|
||||
)
|
||||
|
||||
# Save conversations
|
||||
await user1_service.save_app_conversation_info(user1_info)
|
||||
await user2_service.save_app_conversation_info(user2_info)
|
||||
|
||||
# User 1 should only see their conversation
|
||||
user1_page = await user1_service.search_app_conversation_info()
|
||||
assert len(user1_page.items) == 1
|
||||
assert (
|
||||
user1_page.items[0].created_by_user_id
|
||||
== 'a1111111-1111-1111-1111-111111111111'
|
||||
)
|
||||
|
||||
# User 2 should only see their conversation
|
||||
user2_page = await user2_service.search_app_conversation_info()
|
||||
assert len(user2_page.items) == 1
|
||||
assert (
|
||||
user2_page.items[0].created_by_user_id
|
||||
== 'b2222222-2222-2222-2222-222222222222'
|
||||
)
|
||||
|
||||
# User 1 should not be able to get user 2's conversation
|
||||
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
|
||||
assert user2_from_user1 is None
|
||||
|
||||
# User 2 should not be able to get user 1's conversation
|
||||
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
|
||||
assert user1_from_user2 is None
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
@@ -19,14 +19,6 @@ def mock_session_maker(mock_session):
|
||||
return session_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Mock user with org_id."""
|
||||
user = MagicMock()
|
||||
user.current_org_id = 'test-org-123'
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_store(mock_session_maker):
|
||||
return ApiKeyStore(mock_session_maker)
|
||||
@@ -39,13 +31,11 @@ def test_generate_api_key(api_key_store):
|
||||
assert len(key) == 32
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
def test_create_api_key(api_key_store, mock_session):
|
||||
"""Test creating an API key."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
name = 'Test Key'
|
||||
mock_get_user.return_value = mock_user
|
||||
api_key_store.generate_api_key = MagicMock(return_value='test-api-key')
|
||||
|
||||
# Execute
|
||||
@@ -53,15 +43,10 @@ def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
|
||||
# Verify
|
||||
assert result == 'test-api-key'
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
api_key_store.generate_api_key.assert_called_once()
|
||||
|
||||
# Verify the ApiKey was created with the correct org_id
|
||||
added_api_key = mock_session.add.call_args[0][0]
|
||||
assert added_api_key.org_id == mock_user.current_org_id
|
||||
|
||||
|
||||
def test_validate_api_key_valid(api_key_store, mock_session):
|
||||
"""Test validating a valid API key."""
|
||||
@@ -173,12 +158,10 @@ def test_delete_api_key_by_id(api_key_store, mock_session):
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
def test_list_api_keys(api_key_store, mock_session):
|
||||
"""Test listing API keys for a user."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
now = datetime.now(UTC)
|
||||
mock_key1 = MagicMock()
|
||||
mock_key1.id = 1
|
||||
@@ -194,17 +177,15 @@ def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
mock_key2.last_used_at = None
|
||||
mock_key2.expires_at = None
|
||||
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_key1, mock_key2]
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = [
|
||||
mock_key1,
|
||||
mock_key2,
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.list_api_keys(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert len(result) == 2
|
||||
assert result[0]['id'] == 1
|
||||
assert result[0]['name'] == 'Key 1'
|
||||
@@ -217,59 +198,3 @@ def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
assert result[1]['created_at'] == now
|
||||
assert result[1]['last_used_at'] is None
|
||||
assert result[1]['expires_at'] is None
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_retrieve_mcp_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
"""Test retrieving MCP API key for a user."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
mock_mcp_key = MagicMock()
|
||||
mock_mcp_key.name = 'MCP_API_KEY'
|
||||
mock_mcp_key.key = 'mcp-test-key'
|
||||
|
||||
mock_other_key = MagicMock()
|
||||
mock_other_key.name = 'Other Key'
|
||||
mock_other_key.key = 'other-test-key'
|
||||
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_other_key, mock_mcp_key]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert result == 'mcp-test-key'
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_retrieve_mcp_api_key_not_found(
|
||||
mock_get_user, api_key_store, mock_session, mock_user
|
||||
):
|
||||
"""Test retrieving MCP API key when none exists."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
mock_other_key = MagicMock()
|
||||
mock_other_key.name = 'Other Key'
|
||||
mock_other_key.key = 'other-test-key'
|
||||
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_other_key]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert result is None
|
||||
|
||||
@@ -127,7 +127,6 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -141,15 +140,6 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = None
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = False
|
||||
|
||||
@@ -171,19 +161,20 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.set_response_cookie') as mock_set_cookie,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.posthog') as mock_posthog,
|
||||
):
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
# Mock the session and query results
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
# Mock user settings with accepted_tos
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -235,20 +226,20 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
),
|
||||
patch('server.routes.auth.KEYCLOAK_REALM_NAME', 'test-realm'),
|
||||
patch('server.routes.auth.KEYCLOAK_CLIENT_ID', 'test-client'),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.posthog') as mock_posthog,
|
||||
):
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
# Mock the session and query results
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
|
||||
# Mock user settings with accepted_tos
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import stripe
|
||||
from fastapi import HTTPException, Request, status
|
||||
from httpx import Response
|
||||
from server.routes import billing
|
||||
from httpx import HTTPStatusError, Response
|
||||
from integrations.stripe_service import has_payment_method
|
||||
from server.routes.billing import (
|
||||
CreateBillingSessionResponse,
|
||||
CreateCheckoutSessionRequest,
|
||||
GetCreditsResponse,
|
||||
cancel_callback,
|
||||
cancel_subscription,
|
||||
create_checkout_session,
|
||||
create_customer_setup_session,
|
||||
create_subscription_checkout_session,
|
||||
get_credits,
|
||||
has_payment_method,
|
||||
success_callback,
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session_type import BillingSessionType
|
||||
from storage.stripe_customer import Base as StripeCustomerBase
|
||||
|
||||
|
||||
@@ -78,31 +78,29 @@ def mock_subscription_request():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_lite_llm_error():
|
||||
with (
|
||||
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await get_credits('mock_user')
|
||||
mock_request = Request(scope={'type': 'http', 'state': {'user_id': 'mock_user'}})
|
||||
|
||||
mock_response = Response(
|
||||
status_code=500, json={'error': 'Internal Server Error'}, request=MagicMock()
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
with patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'):
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
with pytest.raises(HTTPStatusError) as exc_info:
|
||||
await get_credits(mock_request)
|
||||
assert (
|
||||
exc_info.value.response.status_code
|
||||
== status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_success():
|
||||
mock_response = Response(
|
||||
status_code=200,
|
||||
json={
|
||||
'user_info': {
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
}
|
||||
},
|
||||
json={'user_info': {'max_budget': 100.00, 'spend': 25.50}},
|
||||
request=MagicMock(),
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
@@ -111,22 +109,24 @@ async def test_get_credits_success():
|
||||
with (
|
||||
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
|
||||
patch('httpx.AsyncClient', return_value=mock_client),
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
},
|
||||
),
|
||||
):
|
||||
result = await get_credits('mock_user')
|
||||
with patch('server.routes.billing.session_maker') as mock_session_maker:
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.first.return_value = MagicMock(
|
||||
billing_margin=4
|
||||
)
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
assert isinstance(result, GetCreditsResponse)
|
||||
assert result.credits == Decimal('74.50') # 100.00 - 25.50 = 74.50
|
||||
result = await get_credits('mock_user')
|
||||
|
||||
assert isinstance(result, GetCreditsResponse)
|
||||
assert result.credits == Decimal(
|
||||
'74.50'
|
||||
) # 100.00 - 25.50 = 74.50 (no billing margin applied)
|
||||
mock_client.__aenter__.return_value.get.assert_called_once_with(
|
||||
'https://llm-proxy.app.all-hands.dev/user/info?user_id=mock_user',
|
||||
headers={'x-goog-api-key': None},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -139,9 +139,6 @@ async def test_create_checkout_session_stripe_error(
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
)
|
||||
mock_customer_create = AsyncMock(return_value=mock_customer)
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = uuid.uuid4()
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
with (
|
||||
pytest.raises(Exception, match='Stripe API Error'),
|
||||
patch('stripe.Customer.create_async', mock_customer_create),
|
||||
@@ -153,10 +150,6 @@ async def test_create_checkout_session_stripe_error(
|
||||
AsyncMock(side_effect=Exception('Stripe API Error')),
|
||||
),
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
@@ -182,10 +175,6 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
)
|
||||
mock_customer_create = AsyncMock(return_value=mock_customer)
|
||||
mock_org = MagicMock()
|
||||
mock_org_id = uuid.uuid4()
|
||||
mock_org.id = mock_org_id
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
with (
|
||||
patch('stripe.Customer.create_async', mock_customer_create),
|
||||
patch(
|
||||
@@ -194,10 +183,6 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
@@ -269,6 +254,7 @@ async def test_success_callback_stripe_incomplete():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
@@ -296,33 +282,44 @@ async def test_success_callback_success():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
mock_lite_llm_response = Response(
|
||||
status_code=200,
|
||||
json={'user_info': {'max_budget': 100.00, 'spend': 25.50}},
|
||||
request=MagicMock(),
|
||||
)
|
||||
mock_lite_llm_update_response = Response(
|
||||
status_code=200, json={}, request=MagicMock()
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
},
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
|
||||
) as mock_update_budget,
|
||||
patch('httpx.AsyncClient') as mock_client,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_user_settings = MagicMock(billing_margin=None)
|
||||
mock_db_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_user_settings
|
||||
)
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete', amount_subtotal=2500, customer='mock_customer_id'
|
||||
status='complete',
|
||||
amount_subtotal=2500,
|
||||
) # $25.00 in cents
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__.return_value.get.return_value = (
|
||||
mock_lite_llm_response
|
||||
)
|
||||
mock_client_instance.__aenter__.return_value.post.return_value = (
|
||||
mock_lite_llm_update_response
|
||||
)
|
||||
mock_client.return_value = mock_client_instance
|
||||
|
||||
response = await success_callback('test_session_id', mock_request)
|
||||
|
||||
assert response.status_code == 302
|
||||
@@ -332,14 +329,18 @@ async def test_success_callback_success():
|
||||
)
|
||||
|
||||
# Verify LiteLLM API calls
|
||||
mock_update_budget.assert_called_once_with(
|
||||
'mock_org_id',
|
||||
125.0, # 100 + (25.00 from Stripe)
|
||||
mock_client_instance.__aenter__.return_value.get.assert_called_once()
|
||||
mock_client_instance.__aenter__.return_value.post.assert_called_once_with(
|
||||
'https://llm-proxy.app.all-hands.dev/user/update',
|
||||
headers={'x-goog-api-key': None},
|
||||
json={
|
||||
'user_id': 'mock_user',
|
||||
'max_budget': 125,
|
||||
}, # 100 + (25.00 from Stripe)
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
assert mock_billing_session.status == 'completed'
|
||||
assert mock_billing_session.price == 25.0
|
||||
mock_db_session.merge.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@@ -353,27 +354,27 @@ async def test_success_callback_lite_llm_error():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
patch('httpx.AsyncClient') as mock_client,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete', amount_subtotal=2500
|
||||
status='complete', amount_total=2500
|
||||
)
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__.return_value.get.side_effect = Exception(
|
||||
'LiteLLM API Error'
|
||||
)
|
||||
mock_client.return_value = mock_client_instance
|
||||
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await success_callback('test_session_id', mock_request)
|
||||
|
||||
@@ -397,8 +398,7 @@ async def test_cancel_callback_session_not_found():
|
||||
response = await cancel_callback('test_session_id', mock_request)
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location']
|
||||
== 'http://test.com/settings/billing?checkout=cancel'
|
||||
response.headers['location'] == 'http://test.com/settings?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify no database updates occurred
|
||||
@@ -424,8 +424,7 @@ async def test_cancel_callback_success():
|
||||
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location']
|
||||
== 'http://test.com/settings/billing?checkout=cancel'
|
||||
response.headers['location'] == 'http://test.com/settings?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
@@ -437,67 +436,314 @@ async def test_cancel_callback_success():
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_payment_method_with_payment_method():
|
||||
"""Test has_payment_method returns True when user has a payment method."""
|
||||
|
||||
mock_has_payment_method = AsyncMock(return_value=True)
|
||||
with patch(
|
||||
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
|
||||
mock_has_payment_method,
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[MagicMock()])),
|
||||
) as mock_list_payment_methods,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
MagicMock(stripe_customer_id='cus_test123')
|
||||
)
|
||||
|
||||
result = await has_payment_method('mock_user')
|
||||
assert result is True
|
||||
mock_has_payment_method.assert_called_once_with('mock_user')
|
||||
mock_list_payment_methods.assert_called_once_with('cus_test123')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_payment_method_without_payment_method():
|
||||
"""Test has_payment_method returns False when user has no payment method."""
|
||||
mock_has_payment_method = AsyncMock(return_value=False)
|
||||
with patch(
|
||||
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
|
||||
mock_has_payment_method,
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[])),
|
||||
) as mock_list_payment_methods,
|
||||
):
|
||||
mock_has_payment_method.return_value = False
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
MagicMock(stripe_customer_id='cus_test123')
|
||||
)
|
||||
|
||||
result = await has_payment_method('mock_user')
|
||||
assert result is False
|
||||
mock_has_payment_method.assert_called_once_with('mock_user')
|
||||
mock_list_payment_methods.assert_called_once_with('cus_test123')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_customer_setup_session_success():
|
||||
"""Test successful creation of customer setup session."""
|
||||
mock_request = Request(
|
||||
scope={
|
||||
'type': 'http',
|
||||
'path': '/api/billing/create-customer-setup-session',
|
||||
'server': ('test.com', 80),
|
||||
'headers': [],
|
||||
}
|
||||
)
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
async def test_cancel_subscription_success():
|
||||
"""Test successful subscription cancellation."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
mock_customer_info = {'customer_id': 'mock-customer-id', 'org_id': 'mock-org-id'}
|
||||
mock_session = MagicMock()
|
||||
mock_session.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_create = AsyncMock(return_value=mock_session)
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
# Mock Stripe subscription response
|
||||
mock_stripe_subscription = MagicMock()
|
||||
mock_stripe_subscription.cancel_at_period_end = True
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'integrations.stripe_service.find_or_create_customer_by_user_id',
|
||||
AsyncMock(return_value=mock_customer_info),
|
||||
'stripe.Subscription.modify_async',
|
||||
AsyncMock(return_value=mock_stripe_subscription),
|
||||
) as mock_stripe_modify,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function
|
||||
result = await cancel_subscription('test_user')
|
||||
|
||||
# Verify Stripe API was called
|
||||
mock_stripe_modify.assert_called_once_with(
|
||||
'sub_test123', cancel_at_period_end=True
|
||||
)
|
||||
|
||||
# Verify database was updated
|
||||
assert mock_subscription_access.cancelled_at is not None
|
||||
mock_session.merge.assert_called_once_with(mock_subscription_access)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Verify response
|
||||
assert result.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_no_active_subscription():
|
||||
"""Test cancellation when no active subscription exists."""
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Setup mock session with no subscription found
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert 'No active subscription found' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_missing_stripe_id():
|
||||
"""Test cancellation when subscription has no Stripe ID."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock subscription without Stripe ID
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id=None, # Missing Stripe ID
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert 'missing Stripe subscription ID' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_stripe_error():
|
||||
"""Test cancellation when Stripe API fails."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Subscription.modify_async',
|
||||
AsyncMock(side_effect=stripe.StripeError('API Error')),
|
||||
),
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to cancel subscription' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_duplicate_prevention(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test that creating a subscription when user already has active subscription raises error."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
result = await create_customer_setup_session(mock_request, 'mock_user')
|
||||
# Setup mock session to return existing active subscription
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
assert isinstance(result, billing.CreateBillingSessionResponse)
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert (
|
||||
'user already has an active subscription'
|
||||
in str(exc_info.value.detail).lower()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_allows_after_cancellation(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test that creating a subscription is allowed when previous subscription was cancelled."""
|
||||
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_session_obj.id = 'test_session_id'
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'integrations.stripe_service.find_or_create_customer',
|
||||
AsyncMock(return_value='cus_test123'),
|
||||
),
|
||||
patch(
|
||||
'stripe.checkout.Session.create_async',
|
||||
AsyncMock(return_value=mock_session_obj),
|
||||
),
|
||||
patch(
|
||||
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
|
||||
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session - the query should return None because cancelled subscriptions are filtered out
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Should succeed
|
||||
result = await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
assert result.redirect_url == 'https://checkout.stripe.com/test-session'
|
||||
|
||||
# Verify Stripe session creation parameters
|
||||
mock_create.assert_called_once_with(
|
||||
customer='mock-customer-id',
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url='http://test.com/?free_credits=success',
|
||||
cancel_url='http://test.com/',
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_success_no_existing(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test successful subscription creation when no existing subscription."""
|
||||
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_session_obj.id = 'test_session_id'
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'integrations.stripe_service.find_or_create_customer',
|
||||
AsyncMock(return_value='cus_test123'),
|
||||
),
|
||||
patch(
|
||||
'stripe.checkout.Session.create_async',
|
||||
AsyncMock(return_value=mock_session_obj),
|
||||
),
|
||||
patch(
|
||||
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
|
||||
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session to return no existing subscription
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Should succeed
|
||||
result = await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
assert result.redirect_url == 'https://checkout.stripe.com/test-session'
|
||||
|
||||
@@ -3,29 +3,14 @@ Tests for ConversationCallbackProcessor and ConversationCallback models.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
# Import the actual StoredConversationMetadata from OpenHands core
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
|
||||
# Mock the lazy import to return the actual class
|
||||
with patch(
|
||||
'storage.stored_conversation_metadata.StoredConversationMetadata',
|
||||
StoredConversationMetadata,
|
||||
):
|
||||
from storage.conversation_callback import (
|
||||
CallbackStatus,
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
from storage.conversation_callback import (
|
||||
CallbackStatus,
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
|
||||
@@ -95,22 +80,15 @@ class TestConversationCallback:
|
||||
"""Create a test conversation metadata record."""
|
||||
with session_maker() as session:
|
||||
metadata = StoredConversationMetadata(
|
||||
conversation_id='test_conversation_123'
|
||||
)
|
||||
metadata_saas = StoredConversationMetadataSaas(
|
||||
conversation_id='test_conversation_123',
|
||||
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
conversation_id='test_conversation_123', user_id='test_user_456'
|
||||
)
|
||||
session.add(metadata)
|
||||
session.add(metadata_saas)
|
||||
session.commit()
|
||||
session.refresh(metadata)
|
||||
yield metadata
|
||||
|
||||
# Cleanup
|
||||
session.delete(metadata)
|
||||
session.delete(metadata_saas)
|
||||
session.commit()
|
||||
|
||||
def test_callback_creation(self, conversation_metadata, session_maker):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from unittest import TestCase, mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.github.github_view import GithubFactory, GithubIssue, get_oh_labels
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.types import UserData
|
||||
@@ -115,10 +114,8 @@ class TestGithubV1ConversationRouting(TestCase):
|
||||
title='Test Issue',
|
||||
description='Test issue description',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.github.github_view.get_user_v1_enabled_setting')
|
||||
@patch.object(GithubIssue, '_create_v0_conversation')
|
||||
@patch.object(GithubIssue, '_create_v1_conversation')
|
||||
@@ -147,7 +144,6 @@ class TestGithubV1ConversationRouting(TestCase):
|
||||
)
|
||||
mock_create_v1.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.github.github_view.get_user_v1_enabled_setting')
|
||||
@patch.object(GithubIssue, '_create_v0_conversation')
|
||||
@patch.object(GithubIssue, '_create_v1_conversation')
|
||||
@@ -176,7 +172,6 @@ class TestGithubV1ConversationRouting(TestCase):
|
||||
)
|
||||
mock_create_v0.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.github.github_view.get_user_v1_enabled_setting')
|
||||
@patch.object(GithubIssue, '_create_v0_conversation')
|
||||
@patch.object(GithubIssue, '_create_v1_conversation')
|
||||
|
||||
@@ -1,650 +0,0 @@
|
||||
"""
|
||||
Unit tests for LiteLlmManager class.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
from server.constants import (
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.server.settings import Settings
|
||||
|
||||
|
||||
class TestLiteLlmManager:
|
||||
"""Test cases for LiteLlmManager class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings(self):
|
||||
"""Create a mock Settings object."""
|
||||
settings = Settings()
|
||||
settings.agent = 'TestAgent'
|
||||
settings.llm_model = 'test-model'
|
||||
settings.llm_api_key = SecretStr('test-key')
|
||||
settings.llm_base_url = 'http://test.com'
|
||||
return settings
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_settings(self):
|
||||
"""Create a mock UserSettings object."""
|
||||
user_settings = UserSettings()
|
||||
user_settings.agent = 'TestAgent'
|
||||
user_settings.llm_model = 'test-model'
|
||||
user_settings.llm_api_key = SecretStr('test-key')
|
||||
user_settings.llm_base_url = 'http://test.com'
|
||||
return user_settings
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_client(self):
|
||||
"""Create a mock HTTP client."""
|
||||
client = AsyncMock(spec=httpx.AsyncClient)
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response(self):
|
||||
"""Create a mock HTTP response."""
|
||||
response = MagicMock()
|
||||
response.is_success = True
|
||||
response.status_code = 200
|
||||
response.text = 'Success'
|
||||
response.json.return_value = {'key': 'test-api-key'}
|
||||
response.raise_for_status = MagicMock()
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def mock_team_response(self):
|
||||
"""Create a mock team response."""
|
||||
response = MagicMock()
|
||||
response.is_success = True
|
||||
response.status_code = 200
|
||||
response.json.return_value = {
|
||||
'team_memberships': [
|
||||
{
|
||||
'user_id': 'test-user-id',
|
||||
'team_id': 'test-org-id',
|
||||
'max_budget': 100.0,
|
||||
}
|
||||
]
|
||||
}
|
||||
response.raise_for_status = MagicMock()
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_response(self):
|
||||
"""Create a mock user response."""
|
||||
response = MagicMock()
|
||||
response.is_success = True
|
||||
response.status_code = 200
|
||||
response.json.return_value = {
|
||||
'user_info': {
|
||||
'max_budget': 50.0,
|
||||
'spend': 10.0,
|
||||
}
|
||||
}
|
||||
response.raise_for_status = MagicMock()
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def mock_key_info_response(self):
|
||||
"""Create a mock key info response."""
|
||||
response = MagicMock()
|
||||
response.is_success = True
|
||||
response.status_code = 200
|
||||
response.json.return_value = {
|
||||
'info': {
|
||||
'max_budget': 100.0,
|
||||
'spend': 25.0,
|
||||
}
|
||||
}
|
||||
response.raise_for_status = MagicMock()
|
||||
return response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_missing_config(self, mock_settings):
|
||||
"""Test create_entries when LiteLLM config is missing."""
|
||||
with patch.dict(os.environ, {'LITE_LLM_API_KEY': '', 'LITE_LLM_API_URL': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_local_deployment(self, mock_settings):
|
||||
"""Test create_entries in local deployment mode."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': '1'}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.llm_model == get_default_litellm_model()
|
||||
assert result.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_cloud_deployment(self, mock_settings, mock_response):
|
||||
"""Test create_entries in cloud deployment mode."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.llm_model == get_default_litellm_model()
|
||||
assert (
|
||||
result.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
)
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
# Verify API calls were made
|
||||
assert (
|
||||
mock_client.post.call_count == 4
|
||||
) # create_team, create_user, add_user_to_team, generate_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_missing_config(self, mock_user_settings):
|
||||
"""Test migrate_entries when LiteLLM config is missing."""
|
||||
with patch.dict(os.environ, {'LITE_LLM_API_KEY': '', 'LITE_LLM_API_URL': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_local_deployment(self, mock_user_settings):
|
||||
"""Test migrate_entries in local deployment mode."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': '1'}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.llm_model == get_default_litellm_model()
|
||||
assert result.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_no_user_found(self, mock_user_settings):
|
||||
"""Test migrate_entries when user is not found."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
|
||||
# Mock the _get_user method directly to return None
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_user', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
mock_get_user.return_value = None
|
||||
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_already_migrated(
|
||||
self, mock_user_settings, mock_user_response
|
||||
):
|
||||
"""Test migrate_entries when user is already migrated (no max_budget)."""
|
||||
mock_user_response.json.return_value = {
|
||||
'user_info': {
|
||||
'max_budget': None, # Already migrated
|
||||
'spend': 10.0,
|
||||
}
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.get.return_value = mock_user_response
|
||||
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_successful_migration(
|
||||
self, mock_user_settings, mock_user_response, mock_response
|
||||
):
|
||||
"""Test successful migrate_entries operation."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.get.return_value = mock_user_response
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.llm_model == get_default_litellm_model()
|
||||
assert result.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
# Verify migration steps were called
|
||||
assert (
|
||||
mock_client.post.call_count == 4
|
||||
) # create_team, update_user, add_user_to_team, update_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_team_and_users_budget_missing_config(self):
|
||||
"""Test update_team_and_users_budget when LiteLLM config is missing."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
# Should not raise an exception, just return early
|
||||
await LiteLlmManager.update_team_and_users_budget('test-team-id', 100.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_team_and_users_budget_successful(
|
||||
self, mock_team_response, mock_response
|
||||
):
|
||||
"""Test successful update_team_and_users_budget operation."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_client.get.return_value = mock_team_response
|
||||
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
'test-team-id', 100.0
|
||||
)
|
||||
|
||||
# Verify update_team and update_user_in_team were called
|
||||
assert (
|
||||
mock_client.post.call_count == 2
|
||||
) # update_team, update_user_in_team
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _create_team operation."""
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_team(
|
||||
mock_http_client, 'test-alias', 'test-team-id', 100.0
|
||||
)
|
||||
|
||||
mock_http_client.post.assert_called_once()
|
||||
call_args = mock_http_client.post.call_args
|
||||
assert 'http://test.com/team/new' in call_args[0]
|
||||
assert call_args[1]['json']['team_id'] == 'test-team-id'
|
||||
assert call_args[1]['json']['team_alias'] == 'test-alias'
|
||||
assert call_args[1]['json']['max_budget'] == 100.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_already_exists(self, mock_http_client):
|
||||
"""Test _create_team when team already exists."""
|
||||
error_response = MagicMock()
|
||||
error_response.is_success = False
|
||||
error_response.status_code = 400
|
||||
error_response.text = 'Team already exists. Please use a different team id'
|
||||
mock_http_client.post.return_value = error_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with patch.object(
|
||||
LiteLlmManager, '_update_team', new_callable=AsyncMock
|
||||
) as mock_update:
|
||||
await LiteLlmManager._create_team(
|
||||
mock_http_client, 'test-alias', 'test-team-id', 100.0
|
||||
)
|
||||
|
||||
mock_update.assert_called_once_with(
|
||||
mock_http_client, 'test-team-id', 'test-alias', 100.0
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_error(self, mock_http_client):
|
||||
"""Test _create_team with unexpected error."""
|
||||
error_response = MagicMock()
|
||||
error_response.is_success = False
|
||||
error_response.status_code = 500
|
||||
error_response.text = 'Internal server error'
|
||||
error_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
'Server error', request=MagicMock(), response=error_response
|
||||
)
|
||||
mock_http_client.post.return_value = error_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await LiteLlmManager._create_team(
|
||||
mock_http_client, 'test-alias', 'test-team-id', 100.0
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_team_success(self, mock_http_client, mock_team_response):
|
||||
"""Test successful _get_team operation."""
|
||||
mock_http_client.get.return_value = mock_team_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
result = await LiteLlmManager._get_team(
|
||||
mock_http_client, 'test-team-id'
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert 'team_memberships' in result
|
||||
mock_http_client.get.assert_called_once_with(
|
||||
'http://test.com/team/info?team_id=test-team-id'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _create_user operation."""
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
mock_http_client.post.assert_called_once()
|
||||
call_args = mock_http_client.post.call_args
|
||||
assert 'http://test.com/user/new' in call_args[0]
|
||||
assert call_args[1]['json']['user_email'] == 'test@example.com'
|
||||
assert call_args[1]['json']['user_id'] == 'test-user-id'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_duplicate_email(self, mock_http_client, mock_response):
|
||||
"""Test _create_user with duplicate email handling."""
|
||||
# First call fails with duplicate email
|
||||
error_response = MagicMock()
|
||||
error_response.is_success = False
|
||||
error_response.status_code = 400
|
||||
error_response.text = 'duplicate email'
|
||||
|
||||
# Second call succeeds
|
||||
mock_http_client.post.side_effect = [error_response, mock_response]
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
assert mock_http_client.post.call_count == 2
|
||||
# Second call should have None email
|
||||
second_call_args = mock_http_client.post.call_args_list[1]
|
||||
assert second_call_args[1]['json']['user_email'] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_key_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _generate_key operation."""
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
result = await LiteLlmManager._generate_key(
|
||||
mock_http_client,
|
||||
'test-user-id',
|
||||
'test-team-id',
|
||||
'test-alias',
|
||||
{'test': 'metadata'},
|
||||
)
|
||||
|
||||
assert result == 'test-api-key'
|
||||
mock_http_client.post.assert_called_once()
|
||||
call_args = mock_http_client.post.call_args
|
||||
assert 'http://test.com/key/generate' in call_args[0]
|
||||
assert call_args[1]['json']['user_id'] == 'test-user-id'
|
||||
assert call_args[1]['json']['team_id'] == 'test-team-id'
|
||||
assert call_args[1]['json']['key_alias'] == 'test-alias'
|
||||
assert call_args[1]['json']['metadata'] == {'test': 'metadata'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_key_info_success(self, mock_http_client, mock_key_info_response):
|
||||
"""Test successful _get_key_info operation."""
|
||||
mock_http_client.get.return_value = mock_key_info_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with patch('storage.user_store.UserStore') as mock_user_store:
|
||||
# Mock user with org member
|
||||
mock_user = MagicMock()
|
||||
mock_org_member = MagicMock()
|
||||
mock_org_member.org_id = 'test-ord-id'
|
||||
mock_org_member.llm_api_key = 'test-api-key'
|
||||
mock_user.org_members = [mock_org_member]
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
|
||||
result = await LiteLlmManager._get_key_info(
|
||||
mock_http_client, 'test-ord-id', 'test-user-id'
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result['key_max_budget'] == 100.0
|
||||
assert result['key_spend'] == 25.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_key_info_no_user(self, mock_http_client):
|
||||
"""Test _get_key_info when user is not found."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with patch('storage.user_store.UserStore') as mock_user_store:
|
||||
mock_user_store.get_user_by_id.return_value = None
|
||||
|
||||
result = await LiteLlmManager._get_key_info(
|
||||
mock_http_client, 'test-ord-id', 'test-user-id'
|
||||
)
|
||||
|
||||
assert result == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_key_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _delete_key operation."""
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._delete_key(mock_http_client, 'test-key-id')
|
||||
|
||||
mock_http_client.post.assert_called_once()
|
||||
call_args = mock_http_client.post.call_args
|
||||
assert 'http://test.com/key/delete' in call_args[0]
|
||||
assert call_args[1]['json']['keys'] == ['test-key-id']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_key_not_found(self, mock_http_client):
|
||||
"""Test _delete_key when key is not found (404 error)."""
|
||||
error_response = MagicMock()
|
||||
error_response.is_success = False
|
||||
error_response.status_code = 404
|
||||
error_response.text = 'Key not found'
|
||||
mock_http_client.post.return_value = error_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
# Should not raise an exception for 404
|
||||
await LiteLlmManager._delete_key(mock_http_client, 'test-key-id')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_http_client_decorator(self):
|
||||
"""Test the with_http_client decorator functionality."""
|
||||
|
||||
# Create a mock internal function
|
||||
async def mock_internal_fn(client, arg1, arg2, kwarg1=None):
|
||||
return f'client={type(client).__name__}, arg1={arg1}, arg2={arg2}, kwarg1={kwarg1}'
|
||||
|
||||
# Apply the decorator
|
||||
decorated_fn = LiteLlmManager.with_http_client(mock_internal_fn)
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
result = await decorated_fn('test1', 'test2', kwarg1='test3')
|
||||
|
||||
# Verify the client was injected as the first argument
|
||||
assert 'client=AsyncMock' in result
|
||||
assert 'arg1=test1' in result
|
||||
assert 'arg2=test2' in result
|
||||
assert 'kwarg1=test3' in result
|
||||
|
||||
def test_public_methods_exist(self):
|
||||
"""Test that all public wrapper methods exist and are properly decorated."""
|
||||
public_methods = [
|
||||
'create_team',
|
||||
'get_team',
|
||||
'update_team',
|
||||
'create_user',
|
||||
'get_user',
|
||||
'update_user',
|
||||
'delete_user',
|
||||
'add_user_to_team',
|
||||
'get_user_team_info',
|
||||
'update_user_in_team',
|
||||
'generate_key',
|
||||
'get_key_info',
|
||||
'delete_key',
|
||||
]
|
||||
|
||||
for method_name in public_methods:
|
||||
assert hasattr(LiteLlmManager, method_name)
|
||||
method = getattr(LiteLlmManager, method_name)
|
||||
assert callable(method)
|
||||
# The methods are created by the with_http_client decorator, so they're functions
|
||||
# We can verify they exist and are callable, which is the important part
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_missing_config_all_methods(self):
|
||||
"""Test that all methods handle missing configuration gracefully."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
mock_client = AsyncMock()
|
||||
|
||||
# Test all private methods that check for config
|
||||
await LiteLlmManager._create_team(
|
||||
mock_client, 'alias', 'team_id', 100.0
|
||||
)
|
||||
await LiteLlmManager._update_team(
|
||||
mock_client, 'team_id', 'alias', 100.0
|
||||
)
|
||||
await LiteLlmManager._create_user(mock_client, 'email', 'user_id')
|
||||
await LiteLlmManager._update_user(mock_client, 'user_id')
|
||||
await LiteLlmManager._delete_user(mock_client, 'user_id')
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
mock_client, 'user_id', 'team_id', 100.0
|
||||
)
|
||||
await LiteLlmManager._update_user_in_team(
|
||||
mock_client, 'user_id', 'team_id', 100.0
|
||||
)
|
||||
await LiteLlmManager._delete_key(mock_client, 'key_id')
|
||||
|
||||
result1 = await LiteLlmManager._get_team(mock_client, 'team_id')
|
||||
result2 = await LiteLlmManager._get_user(mock_client, 'user_id')
|
||||
result3 = await LiteLlmManager._generate_key(
|
||||
mock_client, 'user_id', 'team_id', 'alias', {}
|
||||
)
|
||||
result4 = await LiteLlmManager._get_user_team_info(
|
||||
mock_client, 'user_id', 'team_id'
|
||||
)
|
||||
result5 = await LiteLlmManager._get_key_info(
|
||||
mock_client, 'test-ord-id', 'user_id'
|
||||
)
|
||||
|
||||
# Methods that return None when config is missing
|
||||
assert result1 is None
|
||||
assert result2 is None
|
||||
assert result3 is None
|
||||
assert result4 is None
|
||||
assert result5 is None
|
||||
|
||||
# Verify no HTTP calls were made
|
||||
mock_client.get.assert_not_called()
|
||||
mock_client.post.assert_not_called()
|
||||
@@ -1,70 +0,0 @@
|
||||
"""
|
||||
Test that the models are correctly defined.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
def test_user_model(session_maker):
|
||||
"""Test that the User model works correctly."""
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test_org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
# Create a test user
|
||||
test_user_id = uuid4()
|
||||
user = User(id=test_user_id, current_org_id=org.id, language='en')
|
||||
session.add(user)
|
||||
session.flush()
|
||||
|
||||
# Create org_member relationship
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=1,
|
||||
llm_api_key='test-api-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
|
||||
# Query the user
|
||||
queried_user = session.query(User).filter(User.id == test_user_id).first()
|
||||
assert queried_user is not None
|
||||
assert queried_user.language == 'en'
|
||||
|
||||
# Query the org
|
||||
queried_org = session.query(Org).filter(Org.id == org.id).first()
|
||||
assert queried_org is not None
|
||||
assert queried_org.name == 'test_org'
|
||||
|
||||
# Query the org_member relationship
|
||||
queried_org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org.id, OrgMember.user_id == user.id)
|
||||
.first()
|
||||
)
|
||||
assert queried_org_member is not None
|
||||
assert queried_org_member.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
@@ -1,253 +0,0 @@
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
# Mock the database module before importing OrgMemberStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
|
||||
|
||||
def test_get_org_members(session_maker):
|
||||
# Test getting org_members by org ID
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user1 = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
user2 = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user1, user2, role])
|
||||
session.flush()
|
||||
|
||||
org_member1 = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user1.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-1',
|
||||
status='active',
|
||||
)
|
||||
org_member2 = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user2.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-2',
|
||||
status='active',
|
||||
)
|
||||
session.add_all([org_member1, org_member2])
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
org_members = OrgMemberStore.get_org_members(org_id)
|
||||
assert len(org_members) == 2
|
||||
api_keys = [om.llm_api_key.get_secret_value() for om in org_members]
|
||||
assert 'test-key-1' in api_keys
|
||||
assert 'test-key-2' in api_keys
|
||||
|
||||
|
||||
def test_get_user_orgs(session_maker):
|
||||
# Test getting org_members by user ID
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org1 = Org(name='test-org-1')
|
||||
org2 = Org(name='test-org-2')
|
||||
session.add_all([org1, org2])
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org1.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
org_member1 = OrgMember(
|
||||
org_id=org1.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-1',
|
||||
status='active',
|
||||
)
|
||||
org_member2 = OrgMember(
|
||||
org_id=org2.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-2',
|
||||
status='active',
|
||||
)
|
||||
session.add_all([org_member1, org_member2])
|
||||
session.commit()
|
||||
user_id = user.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
org_members = OrgMemberStore.get_user_orgs(user_id)
|
||||
assert len(org_members) == 2
|
||||
api_keys = [ou.llm_api_key.get_secret_value() for ou in org_members]
|
||||
assert 'test-key-1' in api_keys
|
||||
assert 'test-key-2' in api_keys
|
||||
|
||||
|
||||
def test_get_org_member(session_maker):
|
||||
# Test getting org_member by org and user ID
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
retrieved_org_member = OrgMemberStore.get_org_member(org_id, user_id)
|
||||
assert retrieved_org_member is not None
|
||||
assert retrieved_org_member.org_id == org_id
|
||||
assert retrieved_org_member.user_id == user_id
|
||||
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key'
|
||||
|
||||
|
||||
def test_add_user_to_org(session_maker):
|
||||
# Test adding a user to an org
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
role_id = role.id
|
||||
|
||||
# Test creation
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
org_member = OrgMemberStore.add_user_to_org(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
llm_api_key='new-test-key',
|
||||
status='active',
|
||||
)
|
||||
|
||||
assert org_member is not None
|
||||
assert org_member.org_id == org_id
|
||||
assert org_member.user_id == user_id
|
||||
assert org_member.role_id == role_id
|
||||
assert org_member.llm_api_key.get_secret_value() == 'new-test-key'
|
||||
assert org_member.status == 'active'
|
||||
|
||||
|
||||
def test_update_user_role_in_org(session_maker):
|
||||
# Test updating user role in org
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role1 = Role(name='admin', rank=1)
|
||||
role2 = Role(name='user', rank=2)
|
||||
session.add_all([user, role1, role2])
|
||||
session.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role1.id,
|
||||
llm_api_key='test-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
role2_id = role2.id
|
||||
|
||||
# Test update
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
updated_org_member = OrgMemberStore.update_user_role_in_org(
|
||||
org_id=org_id, user_id=user_id, role_id=role2_id, status='inactive'
|
||||
)
|
||||
|
||||
assert updated_org_member is not None
|
||||
assert updated_org_member.role_id == role2_id
|
||||
assert updated_org_member.status == 'inactive'
|
||||
|
||||
|
||||
def test_update_user_role_in_org_not_found(session_maker):
|
||||
# Test updating org_member that doesn't exist
|
||||
from uuid import uuid4
|
||||
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
updated_org_member = OrgMemberStore.update_user_role_in_org(
|
||||
org_id=uuid4(), user_id=99999, role_id=1
|
||||
)
|
||||
assert updated_org_member is None
|
||||
|
||||
|
||||
def test_remove_user_from_org(session_maker):
|
||||
# Test removing a user from an org
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
|
||||
# Test removal
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
result = OrgMemberStore.remove_user_from_org(org_id, user_id)
|
||||
assert result is True
|
||||
|
||||
# Verify it's removed
|
||||
retrieved_org_member = OrgMemberStore.get_org_member(org_id, user_id)
|
||||
assert retrieved_org_member is None
|
||||
|
||||
|
||||
def test_remove_user_from_org_not_found(session_maker):
|
||||
# Test removing user from org that doesn't exist
|
||||
from uuid import uuid4
|
||||
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
result = OrgMemberStore.remove_user_from_org(uuid4(), 99999)
|
||||
assert result is False
|
||||
@@ -1,197 +0,0 @@
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
# Mock the database module before importing OrgStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_api():
|
||||
api_key_patch = patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test_key')
|
||||
api_url_patch = patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'
|
||||
)
|
||||
team_id_patch = patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test_team')
|
||||
client_patch = patch('httpx.AsyncClient')
|
||||
|
||||
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.json = MagicMock(return_value={'key': 'test_api_key'})
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.patch.return_value = (
|
||||
mock_response
|
||||
)
|
||||
yield mock_client
|
||||
|
||||
|
||||
def test_get_org_by_id(session_maker, mock_litellm_api):
|
||||
# Test getting org by ID
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_org_by_id(org_id)
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.id == org_id
|
||||
assert retrieved_org.name == 'test-org'
|
||||
|
||||
|
||||
def test_get_org_by_id_not_found(session_maker):
|
||||
# Test getting org by ID when it doesn't exist
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
non_existent_id = uuid.uuid4()
|
||||
retrieved_org = OrgStore.get_org_by_id(non_existent_id)
|
||||
assert retrieved_org is None
|
||||
|
||||
|
||||
def test_list_orgs(session_maker, mock_litellm_api):
|
||||
# Test listing all orgs
|
||||
with session_maker() as session:
|
||||
# Create test orgs
|
||||
org1 = Org(name='test-org-1')
|
||||
org2 = Org(name='test-org-2')
|
||||
session.add_all([org1, org2])
|
||||
session.commit()
|
||||
|
||||
# Test listing
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
orgs = OrgStore.list_orgs()
|
||||
assert len(orgs) >= 2
|
||||
org_names = [org.name for org in orgs]
|
||||
assert 'test-org-1' in org_names
|
||||
assert 'test-org-2' in org_names
|
||||
|
||||
|
||||
def test_update_org(session_maker, mock_litellm_api):
|
||||
# Test updating org details
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org', agent='CodeActAgent')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Test update
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
updated_org = OrgStore.update_org(
|
||||
org_id=org_id, kwargs={'name': 'updated-org', 'agent': 'PlannerAgent'}
|
||||
)
|
||||
|
||||
assert updated_org is not None
|
||||
assert updated_org.name == 'updated-org'
|
||||
assert updated_org.agent == 'PlannerAgent'
|
||||
|
||||
|
||||
def test_update_org_not_found(session_maker):
|
||||
# Test updating org that doesn't exist
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
from uuid import uuid4
|
||||
|
||||
updated_org = OrgStore.update_org(
|
||||
org_id=uuid4(), kwargs={'name': 'updated-org'}
|
||||
)
|
||||
assert updated_org is None
|
||||
|
||||
|
||||
def test_create_org(session_maker, mock_litellm_api):
|
||||
# Test creating a new org
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
org = OrgStore.create_org(kwargs={'name': 'new-org', 'agent': 'CodeActAgent'})
|
||||
|
||||
assert org is not None
|
||||
assert org.name == 'new-org'
|
||||
assert org.agent == 'CodeActAgent'
|
||||
assert org.id is not None
|
||||
|
||||
|
||||
def test_get_org_by_name(session_maker, mock_litellm_api):
|
||||
# Test getting org by name
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org-by-name')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_org_by_name('test-org-by-name')
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.name == 'test-org-by-name'
|
||||
|
||||
|
||||
def test_get_current_org_from_keycloak_user_id(session_maker, mock_litellm_api):
|
||||
# Test getting current org from user ID
|
||||
test_user_id = uuid.uuid4()
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
from storage.user import User
|
||||
|
||||
user = User(id=test_user_id, current_org_id=org.id)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_current_org_from_keycloak_user_id(
|
||||
str(test_user_id)
|
||||
)
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.name == 'test-org'
|
||||
|
||||
|
||||
def test_get_kwargs_from_settings():
|
||||
# Test extracting org kwargs from settings
|
||||
settings = Settings(
|
||||
language='es',
|
||||
agent='CodeActAgent',
|
||||
llm_model='gpt-4',
|
||||
llm_api_key=SecretStr('test-key'),
|
||||
enable_sound_notifications=True,
|
||||
)
|
||||
|
||||
kwargs = OrgStore.get_kwargs_from_settings(settings)
|
||||
|
||||
# Should only include fields that exist in Org model
|
||||
assert 'agent' in kwargs
|
||||
assert 'default_llm_model' in kwargs
|
||||
assert kwargs['agent'] == 'CodeActAgent'
|
||||
assert kwargs['default_llm_model'] == 'gpt-4'
|
||||
# Should not include fields that don't exist in Org model
|
||||
assert 'language' not in kwargs # language is not in Org model
|
||||
assert 'llm_api_key' not in kwargs
|
||||
assert 'llm_model' not in kwargs
|
||||
assert 'enable_sound_notifications' not in kwargs
|
||||
@@ -1,15 +1,32 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from integrations.github.github_view import get_user_proactive_conversation_setting
|
||||
from storage.org import Org
|
||||
from integrations.github.github_view import get_user_proactive_conversation_setting
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
# Mock the call_sync_from_async function to return the result of the function directly
|
||||
def mock_call_sync_from_async(func, *args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
filter = MagicMock()
|
||||
|
||||
# Mock the context manager behavior
|
||||
session.__enter__.return_value = session
|
||||
|
||||
session.query.return_value = query
|
||||
query.filter.return_value = filter
|
||||
|
||||
return session, query, filter
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_no_user_id():
|
||||
"""Test that the function returns False when no user ID is provided."""
|
||||
with patch(
|
||||
@@ -25,82 +42,75 @@ async def test_get_user_proactive_conversation_setting_no_user_id():
|
||||
assert await get_user_proactive_conversation_setting(None) is False
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_not_found():
|
||||
async def test_get_user_proactive_conversation_setting_user_not_found(mock_session):
|
||||
"""Test that False is returned when the user is not found."""
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
session, query, filter = mock_session
|
||||
filter.first.return_value = None
|
||||
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=None,
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is False
|
||||
)
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is False
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_none():
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_none(mock_session):
|
||||
"""Test that False is returned when the user setting is None."""
|
||||
mock_org = MagicMock(spec=Org)
|
||||
mock_org.enable_proactive_conversation_starters = None
|
||||
session, query, filter = mock_session
|
||||
user_settings = MagicMock(spec=UserSettings)
|
||||
user_settings.enable_proactive_conversation_starters = None
|
||||
filter.first.return_value = user_settings
|
||||
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is False
|
||||
)
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is False
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_true():
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_true(mock_session):
|
||||
"""Test that True is returned when the user setting is True and the global setting is True."""
|
||||
mock_org = MagicMock(spec=Org)
|
||||
mock_org.enable_proactive_conversation_starters = True
|
||||
session, query, filter = mock_session
|
||||
user_settings = MagicMock(spec=UserSettings)
|
||||
user_settings.enable_proactive_conversation_starters = True
|
||||
filter.first.return_value = user_settings
|
||||
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is True
|
||||
)
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is True
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_false():
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_false(mock_session):
|
||||
"""Test that False is returned when the user setting is False, regardless of global setting."""
|
||||
mock_org = MagicMock(spec=Org)
|
||||
mock_org.enable_proactive_conversation_starters = False
|
||||
session, query, filter = mock_session
|
||||
user_settings = MagicMock(spec=UserSettings)
|
||||
user_settings.enable_proactive_conversation_starters = False
|
||||
filter.first.return_value = user_settings
|
||||
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is False
|
||||
)
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is False
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
# Mock the database module before importing RoleStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
|
||||
def test_get_role_by_id(session_maker):
|
||||
# Test getting role by ID
|
||||
with session_maker() as session:
|
||||
# Create a test role
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
role_id = role.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_id(role_id)
|
||||
assert retrieved_role is not None
|
||||
assert retrieved_role.id == role_id
|
||||
assert retrieved_role.name == 'admin'
|
||||
|
||||
|
||||
def test_get_role_by_id_not_found(session_maker):
|
||||
# Test getting role by ID when it doesn't exist
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_id(99999)
|
||||
assert retrieved_role is None
|
||||
|
||||
|
||||
def test_get_role_by_name(session_maker):
|
||||
# Test getting role by name
|
||||
with session_maker() as session:
|
||||
# Create a test role
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
role_id = role.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_name('admin')
|
||||
assert retrieved_role is not None
|
||||
assert retrieved_role.id == role_id
|
||||
assert retrieved_role.name == 'admin'
|
||||
|
||||
|
||||
def test_get_role_by_name_not_found(session_maker):
|
||||
# Test getting role by name when it doesn't exist
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_name('nonexistent')
|
||||
assert retrieved_role is None
|
||||
|
||||
|
||||
def test_list_roles(session_maker):
|
||||
# Test listing all roles
|
||||
with session_maker() as session:
|
||||
# Create test roles
|
||||
role1 = Role(name='admin', rank=1)
|
||||
role2 = Role(name='user', rank=2)
|
||||
session.add_all([role1, role2])
|
||||
session.commit()
|
||||
|
||||
# Test listing
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
roles = RoleStore.list_roles()
|
||||
assert len(roles) >= 2
|
||||
role_names = [role.name for role in roles]
|
||||
assert 'admin' in role_names
|
||||
assert 'user' in role_names
|
||||
|
||||
|
||||
def test_create_role(session_maker):
|
||||
# Test creating a new role
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
role = RoleStore.create_role(name='moderator', rank=2)
|
||||
|
||||
assert role is not None
|
||||
assert role.name == 'moderator'
|
||||
assert role.rank == 2
|
||||
assert role.id is not None
|
||||
@@ -1,26 +1,11 @@
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
# Import the actual StoredConversationMetadata from OpenHands core
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
|
||||
# Mock the lazy import to return the actual class
|
||||
with patch(
|
||||
'storage.stored_conversation_metadata.StoredConversationMetadata',
|
||||
StoredConversationMetadata,
|
||||
):
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_call_sync_from_async():
|
||||
@@ -35,25 +20,12 @@ def mock_call_sync_from_async():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_user_store():
|
||||
"""Mock UserStore.get_user_by_id to return a mock user"""
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.current_org_id = UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
with patch(
|
||||
'storage.saas_conversation_store.UserStore.get_user_by_id',
|
||||
return_value=mock_user,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_and_get(session_maker):
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
metadata = ConversationMetadata(
|
||||
conversation_id='my-conversation-id',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
user_id='12345',
|
||||
selected_repository='my-repo',
|
||||
selected_branch=None,
|
||||
created_at=datetime.now(UTC),
|
||||
@@ -75,13 +47,13 @@ async def test_save_and_get(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search(session_maker):
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
|
||||
# Create test conversations with different timestamps
|
||||
conversations = [
|
||||
ConversationMetadata(
|
||||
conversation_id=f'conv-{i}',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
user_id='12345',
|
||||
selected_repository='repo',
|
||||
selected_branch=None,
|
||||
created_at=datetime(2024, 1, i + 1, tzinfo=UTC),
|
||||
@@ -120,10 +92,10 @@ async def test_search(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_metadata(session_maker):
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
metadata = ConversationMetadata(
|
||||
conversation_id='to-delete',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
user_id='12345',
|
||||
selected_repository='repo',
|
||||
selected_branch=None,
|
||||
created_at=datetime.now(UTC),
|
||||
@@ -140,17 +112,17 @@ async def test_delete_metadata(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_metadata(session_maker):
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await store.get_metadata('nonexistent-id')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists(session_maker):
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
metadata = ConversationMetadata(
|
||||
conversation_id='exists-test',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
user_id='12345',
|
||||
selected_repository='repo',
|
||||
selected_branch='test-branch',
|
||||
created_at=datetime.now(UTC),
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
@@ -20,14 +19,6 @@ def mock_config():
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Mock user with org_id."""
|
||||
user = MagicMock()
|
||||
user.current_org_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def secrets_store(session_maker, mock_config):
|
||||
return SaasSecretsStore('user-id', session_maker, mock_config)
|
||||
@@ -35,11 +26,7 @@ def secrets_store(session_maker, mock_config):
|
||||
|
||||
class TestSaasSecretsStore:
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
|
||||
async def test_store_and_load(self, mock_get_user, secrets_store, mock_user):
|
||||
# Setup mock
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
async def test_store_and_load(self, secrets_store):
|
||||
# Create a Secrets object with some test data
|
||||
user_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
@@ -72,10 +59,7 @@ class TestSaasSecretsStore:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
|
||||
async def test_encryption_decryption(self, mock_get_user, secrets_store, mock_user):
|
||||
# Setup mock
|
||||
mock_get_user.return_value = mock_user
|
||||
async def test_encryption_decryption(self, secrets_store):
|
||||
# Create a Secrets object with sensitive data
|
||||
user_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
@@ -105,7 +89,6 @@ class TestSaasSecretsStore:
|
||||
stored = (
|
||||
session.query(StoredCustomSecrets)
|
||||
.filter(StoredCustomSecrets.keycloak_user_id == 'user-id')
|
||||
.filter(StoredCustomSecrets.org_id == mock_user.current_org_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -169,12 +152,7 @@ class TestSaasSecretsStore:
|
||||
assert await secrets_store.load() is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
|
||||
async def test_update_existing_secrets(
|
||||
self, mock_get_user, secrets_store, mock_user
|
||||
):
|
||||
# Setup mock
|
||||
mock_get_user.return_value = mock_user
|
||||
async def test_update_existing_secrets(self, secrets_store):
|
||||
# Create and store initial secrets
|
||||
initial_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
|
||||
@@ -2,17 +2,65 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
from server.constants import (
|
||||
CURRENT_USER_SETTINGS_VERSION,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
)
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.settings import Settings
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from server.constants import (
|
||||
LITE_LLM_API_URL,
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_get_response():
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.json = MagicMock(return_value={'user_info': {}})
|
||||
return mock_response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_post_response():
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.json = MagicMock(return_value={'key': 'test_api_key'})
|
||||
return mock_response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_api(mock_litellm_get_response, mock_litellm_post_response):
|
||||
api_key_patch = patch('storage.saas_settings_store.LITE_LLM_API_KEY', 'test_key')
|
||||
api_url_patch = patch(
|
||||
'storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'
|
||||
)
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
team_id_patch = patch('storage.saas_settings_store.LITE_LLM_TEAM_ID', 'test_team')
|
||||
client_patch = patch('httpx.AsyncClient')
|
||||
|
||||
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_litellm_get_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_litellm_post_response
|
||||
)
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stripe():
|
||||
search_patch = patch(
|
||||
'stripe.Customer.search_async',
|
||||
AsyncMock(return_value=MagicMock(id='mock-customer-id')),
|
||||
)
|
||||
payment_patch = patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[{}])),
|
||||
)
|
||||
with search_patch, payment_patch:
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -35,42 +83,41 @@ def mock_config():
|
||||
|
||||
@pytest.fixture
|
||||
def settings_store(session_maker, mock_config):
|
||||
store = SaasSettingsStore(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker, mock_config
|
||||
)
|
||||
store = SaasSettingsStore('user-id', session_maker, mock_config)
|
||||
|
||||
# Patch the load method to read from UserSettings table directly (for testing)
|
||||
# Patch the store method directly to filter out email and email_verified
|
||||
original_load = store.load
|
||||
original_create_default = store.create_default_settings
|
||||
original_update_litellm = store.update_settings_with_litellm_default
|
||||
|
||||
# Patch the load method to add email and email_verified
|
||||
async def patched_load():
|
||||
with store.session_maker() as session:
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == store.user_id)
|
||||
.first()
|
||||
)
|
||||
if not user_settings:
|
||||
# Return default settings
|
||||
return Settings(
|
||||
llm_api_key=SecretStr('test_api_key'),
|
||||
llm_base_url='http://test.url',
|
||||
agent='CodeActAgent',
|
||||
language='en',
|
||||
)
|
||||
|
||||
# Decrypt and convert to Settings
|
||||
kwargs = {}
|
||||
for column in UserSettings.__table__.columns:
|
||||
if column.name != 'keycloak_user_id':
|
||||
value = getattr(user_settings, column.name, None)
|
||||
if value is not None:
|
||||
kwargs[column.name] = value
|
||||
|
||||
store._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
settings = await original_load()
|
||||
if settings:
|
||||
# Add email and email_verified fields to mimic SaasUserAuth behavior
|
||||
settings.email = 'test@example.com'
|
||||
settings.email_verified = True
|
||||
return settings
|
||||
return settings
|
||||
|
||||
# Patch the store method to write to UserSettings table directly (for testing)
|
||||
# Patch the create_default_settings method to add email and email_verified
|
||||
async def patched_create_default(settings):
|
||||
settings = await original_create_default(settings)
|
||||
if settings:
|
||||
# Add email and email_verified fields to mimic SaasUserAuth behavior
|
||||
settings.email = 'test@example.com'
|
||||
settings.email_verified = True
|
||||
return settings
|
||||
|
||||
# Patch the update_settings_with_litellm_default method
|
||||
async def patched_update_litellm(settings):
|
||||
updated_settings = await original_update_litellm(settings)
|
||||
if updated_settings:
|
||||
# Add email and email_verified fields to mimic SaasUserAuth behavior
|
||||
updated_settings.email = 'test@example.com'
|
||||
updated_settings.email_verified = True
|
||||
return updated_settings
|
||||
|
||||
# Patch the store method to filter out email and email_verified
|
||||
async def patched_store(item):
|
||||
if item:
|
||||
# Make a copy of the item without email and email_verified
|
||||
@@ -99,9 +146,11 @@ def settings_store(session_maker, mock_config):
|
||||
for key, value in item_dict.items():
|
||||
if key in existing.__class__.__table__.columns:
|
||||
setattr(existing, key, value)
|
||||
existing.user_version = CURRENT_USER_SETTINGS_VERSION
|
||||
session.merge(existing)
|
||||
else:
|
||||
item_dict['keycloak_user_id'] = store.user_id
|
||||
item_dict['user_version'] = CURRENT_USER_SETTINGS_VERSION
|
||||
settings = UserSettings(**item_dict)
|
||||
session.add(settings)
|
||||
session.commit()
|
||||
@@ -109,6 +158,8 @@ def settings_store(session_maker, mock_config):
|
||||
# Replace the methods with our patched versions
|
||||
store.store = patched_store
|
||||
store.load = patched_load
|
||||
store.create_default_settings = patched_create_default
|
||||
store.update_settings_with_litellm_default = patched_update_litellm
|
||||
return store
|
||||
|
||||
|
||||
@@ -146,11 +197,17 @@ async def test_store_and_load_keycloak_user(settings_store):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_returns_default_when_not_found(settings_store, session_maker):
|
||||
async def test_load_returns_default_when_not_found(
|
||||
settings_store, mock_litellm_api, mock_stripe, mock_github_user, session_maker
|
||||
):
|
||||
file_store = MagicMock()
|
||||
file_store.read.side_effect = FileNotFoundError()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'storage.saas_settings_store.get_file_store',
|
||||
MagicMock(return_value=file_store),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
loaded_settings = await settings_store.load()
|
||||
@@ -161,9 +218,233 @@ async def test_load_returns_default_when_not_found(settings_store, session_maker
|
||||
assert loaded_settings.llm_base_url == 'http://test.url'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
settings = Settings()
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
settings = await settings_store.update_settings_with_litellm_default(settings)
|
||||
|
||||
assert settings.agent == 'CodeActAgent'
|
||||
assert settings.llm_api_key
|
||||
assert settings.llm_api_key.get_secret_value() == 'test_api_key'
|
||||
assert settings.llm_base_url == 'http://test.url'
|
||||
|
||||
# Get the actual call arguments
|
||||
call_args = mock_litellm_api.return_value.__aenter__.return_value.post.call_args[1]
|
||||
|
||||
# Check that the URL and most of the JSON payload match what we expect
|
||||
assert call_args['json']['user_email'] == 'testy@tester.com'
|
||||
assert call_args['json']['models'] == []
|
||||
assert call_args['json']['max_budget'] == 10.0
|
||||
assert call_args['json']['user_id'] == 'user-id'
|
||||
assert call_args['json']['teams'] == ['test_team']
|
||||
assert call_args['json']['auto_create_key'] is True
|
||||
assert call_args['json']['send_invite_email'] is False
|
||||
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
|
||||
assert 'model' in call_args['json']['metadata']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_no_user_id():
|
||||
store = SaasSettingsStore('', MagicMock(), MagicMock())
|
||||
settings = await store.create_default_settings(None)
|
||||
assert settings is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_require_payment_enabled(
|
||||
settings_store, mock_stripe
|
||||
):
|
||||
# Mock stripe_service.has_payment_method to return False
|
||||
with (
|
||||
patch('storage.saas_settings_store.REQUIRE_PAYMENT', True),
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[])),
|
||||
),
|
||||
patch(
|
||||
'integrations.stripe_service.session_maker', settings_store.session_maker
|
||||
),
|
||||
):
|
||||
settings = await settings_store.create_default_settings(None)
|
||||
assert settings is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_require_payment_disabled(
|
||||
settings_store, mock_stripe, mock_github_user, mock_litellm_api, session_maker
|
||||
):
|
||||
# Even without payment method, should get default settings when REQUIRE_PAYMENT is False
|
||||
file_store = MagicMock()
|
||||
file_store.read.side_effect = FileNotFoundError()
|
||||
with (
|
||||
patch('storage.saas_settings_store.REQUIRE_PAYMENT', False),
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[])),
|
||||
),
|
||||
patch(
|
||||
'storage.saas_settings_store.get_file_store',
|
||||
MagicMock(return_value=file_store),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
settings = await settings_store.create_default_settings(None)
|
||||
assert settings is not None
|
||||
assert settings.language == 'en'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_lite_llm_settings_no_api_config(settings_store):
|
||||
with (
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_KEY', None),
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', None),
|
||||
):
|
||||
settings = Settings()
|
||||
settings = await settings_store.update_settings_with_litellm_default(settings)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_error(settings_store):
|
||||
with patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'duplicate@example.com'}),
|
||||
):
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
AsyncMock(
|
||||
json=MagicMock(
|
||||
return_value={'user_info': {'max_budget': 10, 'spend': 5}}
|
||||
)
|
||||
)
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value.is_success = False
|
||||
settings = Settings()
|
||||
settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
assert settings is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_retry_on_duplicate_email(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# First response is a delete and succeeds
|
||||
mock_delete_response = MagicMock()
|
||||
mock_delete_response.is_success = True
|
||||
mock_delete_response.status_code = 200
|
||||
|
||||
# Second response fails with duplicate email error
|
||||
mock_error_response = MagicMock()
|
||||
mock_error_response.is_success = False
|
||||
mock_error_response.status_code = 400
|
||||
mock_error_response.text = 'User with this email already exists'
|
||||
|
||||
# Thire response succeeds with no email
|
||||
mock_success_response = MagicMock()
|
||||
mock_success_response.is_success = True
|
||||
mock_success_response.json = MagicMock(return_value={'key': 'new_test_api_key'})
|
||||
|
||||
# Set up mocks
|
||||
post_mock = AsyncMock()
|
||||
post_mock.side_effect = [
|
||||
mock_delete_response,
|
||||
mock_error_response,
|
||||
mock_success_response,
|
||||
]
|
||||
mock_litellm_api.return_value.__aenter__.return_value.post = post_mock
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'duplicate@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
settings = Settings()
|
||||
settings = await settings_store.update_settings_with_litellm_default(settings)
|
||||
|
||||
assert settings is not None
|
||||
assert settings.llm_api_key
|
||||
assert settings.llm_api_key.get_secret_value() == 'new_test_api_key'
|
||||
|
||||
# Verify second call was with email
|
||||
second_call_args = post_mock.call_args_list[1][1]
|
||||
assert second_call_args['json']['user_email'] == 'duplicate@example.com'
|
||||
|
||||
# Verify third call was with None for email
|
||||
third_call_args = post_mock.call_args_list[2][1]
|
||||
assert third_call_args['json']['user_email'] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_in_lite_llm(settings_store):
|
||||
# Test the _create_user_in_lite_llm method directly
|
||||
mock_client = AsyncMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
# Test with email
|
||||
await settings_store._create_user_in_lite_llm(
|
||||
mock_client, 'test@example.com', 50, 10
|
||||
)
|
||||
|
||||
# Get the actual call arguments
|
||||
call_args = mock_client.post.call_args[1]
|
||||
|
||||
# Check that the URL and most of the JSON payload match what we expect
|
||||
assert call_args['json']['user_email'] == 'test@example.com'
|
||||
assert call_args['json']['models'] == []
|
||||
assert call_args['json']['max_budget'] == 50
|
||||
assert call_args['json']['spend'] == 10
|
||||
assert call_args['json']['user_id'] == 'user-id'
|
||||
assert call_args['json']['teams'] == [LITE_LLM_TEAM_ID]
|
||||
assert call_args['json']['auto_create_key'] is True
|
||||
assert call_args['json']['send_invite_email'] is False
|
||||
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
|
||||
assert 'model' in call_args['json']['metadata']
|
||||
|
||||
# Test with None email
|
||||
mock_client.post.reset_mock()
|
||||
await settings_store._create_user_in_lite_llm(mock_client, None, 25, 15)
|
||||
|
||||
# Get the actual call arguments
|
||||
call_args = mock_client.post.call_args[1]
|
||||
|
||||
# Check that the URL and most of the JSON payload match what we expect
|
||||
assert call_args['json']['user_email'] is None
|
||||
assert call_args['json']['models'] == []
|
||||
assert call_args['json']['max_budget'] == 25
|
||||
assert call_args['json']['spend'] == 15
|
||||
assert call_args['json']['user_id'] == str(settings_store.user_id)
|
||||
assert call_args['json']['teams'] == [LITE_LLM_TEAM_ID]
|
||||
assert call_args['json']['auto_create_key'] is True
|
||||
assert call_args['json']['send_invite_email'] is False
|
||||
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
|
||||
assert 'model' in call_args['json']['metadata']
|
||||
|
||||
# Verify response is returned correctly
|
||||
assert (
|
||||
await settings_store._create_user_in_lite_llm(
|
||||
mock_client, 'email@test.com', 30, 7
|
||||
)
|
||||
== mock_response
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_encryption(settings_store):
|
||||
settings_store.user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' # GitHub user ID
|
||||
settings_store.user_id = 'mock-id' # GitHub user ID
|
||||
settings = Settings(
|
||||
llm_api_key=SecretStr('secret_key'),
|
||||
agent='smith',
|
||||
@@ -175,9 +456,7 @@ async def test_encryption(settings_store):
|
||||
with settings_store.session_maker() as session:
|
||||
stored = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
.filter(UserSettings.keycloak_user_id == 'mock-id')
|
||||
.first()
|
||||
)
|
||||
# The stored key should be encrypted
|
||||
|
||||
@@ -3,30 +3,27 @@ This test file verifies that the stripe_service functions properly use the datab
|
||||
to store and retrieve customer IDs.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import stripe
|
||||
from integrations.stripe_service import (
|
||||
find_customer_id_by_user_id,
|
||||
find_or_create_customer_by_user_id,
|
||||
find_or_create_customer,
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
from storage.stripe_customer import Base as StripeCustomerBase
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.user import User
|
||||
from storage.user_settings import Base as UserBase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
# Create all tables using the unified Base
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
UserBase.metadata.create_all(engine)
|
||||
StripeCustomerBase.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@@ -35,158 +32,79 @@ def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_org_and_user(session_maker):
|
||||
"""Create a test org and user for use in tests."""
|
||||
test_user_id = uuid.uuid4()
|
||||
test_org_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
# Create role first
|
||||
role = Role(name='test-role', rank=1)
|
||||
session.add(role)
|
||||
session.flush()
|
||||
|
||||
# Create org
|
||||
org = Org(id=test_org_id, name='test-org', contact_email='testy@tester.com')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
# Create user with current_org_id
|
||||
user = User(id=test_user_id, current_org_id=test_org_id, role_id=role.id)
|
||||
session.add(user)
|
||||
session.flush()
|
||||
|
||||
# Create org member relationship
|
||||
org_member = OrgMember(
|
||||
org_id=test_org_id,
|
||||
user_id=test_user_id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
|
||||
return test_user_id, test_org_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_customer_id_by_user_id_checks_db_first(
|
||||
session_maker, test_org_and_user
|
||||
):
|
||||
async def test_find_customer_id_by_user_id_checks_db_first(session_maker):
|
||||
"""Test that find_customer_id_by_user_id checks the database first"""
|
||||
|
||||
test_user_id, test_org_id = test_org_and_user
|
||||
|
||||
# Set up the mock for the database query result
|
||||
with session_maker() as session:
|
||||
# Create stripe customer
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id=str(test_user_id),
|
||||
org_id=test_org_id,
|
||||
keycloak_user_id='test-user-id',
|
||||
stripe_customer_id='cus_test123',
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Create a mock org object to return from OrgStore
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = test_org_id
|
||||
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
):
|
||||
# Mock the call_sync_from_async to return the org
|
||||
mock_call_sync.return_value = mock_org
|
||||
|
||||
with patch('integrations.stripe_service.session_maker', session_maker):
|
||||
# Call the function
|
||||
result = await find_customer_id_by_user_id(str(test_user_id))
|
||||
result = await find_customer_id_by_user_id('test-user-id')
|
||||
|
||||
# Verify the result
|
||||
assert result == 'cus_test123'
|
||||
|
||||
# Verify that call_sync_from_async was called with the correct function
|
||||
mock_call_sync.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_customer_id_by_user_id_falls_back_to_stripe(
|
||||
session_maker, test_org_and_user
|
||||
):
|
||||
async def test_find_customer_id_by_user_id_falls_back_to_stripe(session_maker):
|
||||
"""Test that find_customer_id_by_user_id falls back to Stripe if not found in the database"""
|
||||
|
||||
test_user_id, test_org_id = test_org_and_user
|
||||
|
||||
# Set up the mock for stripe.Customer.search_async
|
||||
mock_customer = stripe.Customer(id='cus_test123')
|
||||
mock_search = AsyncMock(return_value=MagicMock(data=[mock_customer]))
|
||||
|
||||
# Create a mock org object to return from OrgStore
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = test_org_id
|
||||
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('stripe.Customer.search_async', mock_search),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
):
|
||||
# Mock the call_sync_from_async to return the org
|
||||
mock_call_sync.return_value = mock_org
|
||||
|
||||
# Call the function
|
||||
result = await find_customer_id_by_user_id(str(test_user_id))
|
||||
result = await find_customer_id_by_user_id('test-user-id')
|
||||
|
||||
# Verify the result
|
||||
assert result == 'cus_test123'
|
||||
|
||||
# Verify that Stripe was searched with the org_id
|
||||
# Verify that Stripe was searched
|
||||
mock_search.assert_called_once()
|
||||
assert (
|
||||
f"metadata['org_id']:'{str(test_org_id)}'" in mock_search.call_args[1]['query']
|
||||
)
|
||||
assert "metadata['user_id']:'test-user-id'" in mock_search.call_args[1]['query']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_customer_stores_id_in_db(session_maker, test_org_and_user):
|
||||
async def test_create_customer_stores_id_in_db(session_maker):
|
||||
"""Test that create_customer stores the customer ID in the database"""
|
||||
|
||||
test_user_id, test_org_id = test_org_and_user
|
||||
|
||||
# Set up the mock for stripe.Customer.search_async and create_async
|
||||
# Set up the mock for stripe.Customer.search_async
|
||||
mock_search = AsyncMock(return_value=MagicMock(data=[]))
|
||||
mock_create_async = AsyncMock(return_value=stripe.Customer(id='cus_test123'))
|
||||
|
||||
# Create a mock org object to return from OrgStore
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = test_org_id
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('stripe.Customer.search_async', mock_search),
|
||||
patch('stripe.Customer.create_async', mock_create_async),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
):
|
||||
# Mock the call_sync_from_async to return the org
|
||||
mock_call_sync.return_value = mock_org
|
||||
|
||||
# Call the function
|
||||
result = await find_or_create_customer_by_user_id(str(test_user_id))
|
||||
result = await find_or_create_customer('test-user-id')
|
||||
|
||||
# Verify the result
|
||||
assert result == {'customer_id': 'cus_test123', 'org_id': str(test_org_id)}
|
||||
assert result == 'cus_test123'
|
||||
|
||||
# Verify that the stripe customer was stored in the db
|
||||
with session_maker() as session:
|
||||
customer = session.query(StripeCustomer).first()
|
||||
assert customer.id > 0
|
||||
assert customer.keycloak_user_id == str(test_user_id)
|
||||
assert customer.org_id == test_org_id
|
||||
assert customer.keycloak_user_id == 'test-user-id'
|
||||
assert customer.stripe_customer_id == 'cus_test123'
|
||||
assert customer.created_at is not None
|
||||
assert customer.updated_at is not None
|
||||
|
||||
@@ -1,164 +0,0 @@
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
# Mock the database module before importing UserStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from sqlalchemy.orm import configure_mappers
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope='session')
|
||||
def load_all_models():
|
||||
configure_mappers() # fail fast if anything’s missing
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_api():
|
||||
api_key_patch = patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test_key')
|
||||
api_url_patch = patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'
|
||||
)
|
||||
team_id_patch = patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test_team')
|
||||
client_patch = patch('httpx.AsyncClient')
|
||||
|
||||
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.json = MagicMock(return_value={'key': 'test_api_key'})
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_response
|
||||
)
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stripe():
|
||||
search_patch = patch(
|
||||
'stripe.Customer.search_async',
|
||||
AsyncMock(return_value=MagicMock(id='mock-customer-id')),
|
||||
)
|
||||
payment_patch = patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[{}])),
|
||||
)
|
||||
with search_patch, payment_patch:
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_no_org_id():
|
||||
# Test UserStore.create_default_settings with empty org_id
|
||||
settings = await UserStore.create_default_settings('', 'test-user-id')
|
||||
assert settings is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_require_org(session_maker, mock_stripe):
|
||||
# Mock stripe_service.has_payment_method to return False
|
||||
with (
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[])),
|
||||
),
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
):
|
||||
settings = await UserStore.create_default_settings(
|
||||
'test-org-id', 'test-user-id'
|
||||
)
|
||||
assert settings is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_with_litellm(session_maker, mock_litellm_api):
|
||||
# Test that UserStore.create_default_settings works with LiteLLM
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.user_store.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'attributes': {'github_id': ['12345']}}),
|
||||
),
|
||||
):
|
||||
settings = await UserStore.create_default_settings(
|
||||
'test-org-id', 'test-user-id'
|
||||
)
|
||||
assert settings is not None
|
||||
assert settings.llm_api_key.get_secret_value() == 'test_api_key'
|
||||
assert settings.llm_base_url == 'http://test.url'
|
||||
assert settings.agent == 'CodeActAgent'
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Complex integration test with session isolation issues')
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user(session_maker, mock_litellm_api):
|
||||
# Test creating a new user - skipped due to complex session isolation issues
|
||||
pass
|
||||
|
||||
|
||||
def test_get_user_by_id(session_maker):
|
||||
# Test getting user by ID
|
||||
test_org_id = uuid.uuid4()
|
||||
test_user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
with session_maker() as session:
|
||||
# Create a test user
|
||||
user = User(id=uuid.UUID(test_user_id), current_org_id=test_org_id)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
user_id = user.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.user_store.session_maker', session_maker):
|
||||
retrieved_user = UserStore.get_user_by_id(test_user_id)
|
||||
assert retrieved_user is not None
|
||||
assert retrieved_user.id == user_id
|
||||
|
||||
|
||||
def test_list_users(session_maker):
|
||||
# Test listing all users
|
||||
test_org_id1 = uuid.uuid4()
|
||||
test_org_id2 = uuid.uuid4()
|
||||
test_user_id1 = uuid.uuid4()
|
||||
test_user_id2 = uuid.uuid4()
|
||||
with session_maker() as session:
|
||||
# Create test users
|
||||
user1 = User(id=test_user_id1, current_org_id=test_org_id1)
|
||||
user2 = User(id=test_user_id2, current_org_id=test_org_id2)
|
||||
session.add_all([user1, user2])
|
||||
session.commit()
|
||||
|
||||
# Test listing
|
||||
with patch('storage.user_store.session_maker', session_maker):
|
||||
users = UserStore.list_users()
|
||||
assert len(users) >= 2
|
||||
user_ids = [user.id for user in users]
|
||||
assert test_user_id1 in user_ids
|
||||
assert test_user_id2 in user_ids
|
||||
|
||||
|
||||
def test_get_kwargs_from_settings():
|
||||
# Test extracting user kwargs from settings
|
||||
settings = Settings(
|
||||
language='es',
|
||||
enable_sound_notifications=True,
|
||||
llm_api_key=SecretStr('test-key'),
|
||||
)
|
||||
|
||||
kwargs = UserStore.get_kwargs_from_settings(settings)
|
||||
|
||||
# Should only include fields that exist in User model
|
||||
assert 'language' in kwargs
|
||||
assert 'enable_sound_notifications' in kwargs
|
||||
# Should not include fields that don't exist in User model
|
||||
assert 'llm_api_key' not in kwargs
|
||||
@@ -61,7 +61,7 @@ describe("ExpandableMessage", () => {
|
||||
expect(icon).toHaveClass("fill-success");
|
||||
});
|
||||
|
||||
it("should render with no icon for failed action messages", () => {
|
||||
it("should render with error icon for failed action messages", () => {
|
||||
renderWithProviders(
|
||||
<ExpandableMessage
|
||||
id="OBSERVATION_MESSAGE$RUN"
|
||||
@@ -75,7 +75,8 @@ describe("ExpandableMessage", () => {
|
||||
"div.flex.gap-2.items-center.justify-start",
|
||||
);
|
||||
expect(container).toHaveClass("border-neutral-300");
|
||||
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
|
||||
const icon = screen.getByTestId("status-icon");
|
||||
expect(icon).toHaveClass("fill-danger");
|
||||
});
|
||||
|
||||
it("should render with neutral border and no icon for action messages without success prop", () => {
|
||||
|
||||
+1
-1
@@ -3,7 +3,7 @@ import { describe, expect, it, vi } from "vitest";
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { AnalyticsConsentFormModal } from "#/components/features/analytics/analytics-consent-form-modal";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
|
||||
describe("AnalyticsConsentFormModal", () => {
|
||||
it("should call saveUserSettings with consent", async () => {
|
||||
|
||||
@@ -3,7 +3,7 @@ import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { QueryClientProvider, QueryClient } from "@tanstack/react-query";
|
||||
import { createRoutesStub, Outlet } from "react-router";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import ConversationService from "#/api/conversation-service/conversation-service.api";
|
||||
import GitService from "#/api/git-service/git-service.api";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
@@ -404,7 +404,7 @@ describe("RepoConnector", () => {
|
||||
ConversationService,
|
||||
"createConversation",
|
||||
);
|
||||
createConversationSpy.mockImplementation(() => new Promise(() => { })); // Never resolves to keep loading state
|
||||
createConversationSpy.mockImplementation(() => new Promise(() => {})); // Never resolves to keep loading state
|
||||
const retrieveUserGitRepositoriesSpy = vi.spyOn(
|
||||
GitService,
|
||||
"retrieveUserGitRepositories",
|
||||
|
||||
@@ -3,7 +3,7 @@ import { renderWithProviders } from "test-utils";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { waitFor } from "@testing-library/react";
|
||||
import { Sidebar } from "#/components/features/sidebar/sidebar";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
|
||||
// These tests will now fail because the conversation panel is rendered through a portal
|
||||
// and technically not a child of the Sidebar component.
|
||||
|
||||
@@ -57,7 +57,7 @@ describe("MicroagentsModal - Refresh Button", () => {
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("Refresh Button Rendering", () => {
|
||||
@@ -74,15 +74,13 @@ describe("MicroagentsModal - Refresh Button", () => {
|
||||
describe("Refresh Button Functionality", () => {
|
||||
it("should call refetch when refresh button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const refreshSpy = vi.spyOn(ConversationService, "getMicroagents");
|
||||
|
||||
renderWithProviders(<MicroagentsModal {...defaultProps} />);
|
||||
|
||||
const refreshSpy = vi.spyOn(ConversationService, "getMicroagents");
|
||||
|
||||
// Wait for the component to load and render the refresh button
|
||||
const refreshButton = await screen.findByTestId("refresh-microagents");
|
||||
|
||||
refreshSpy.mockClear();
|
||||
|
||||
await user.click(refreshButton);
|
||||
|
||||
expect(refreshSpy).toHaveBeenCalledTimes(1);
|
||||
|
||||
@@ -3,7 +3,7 @@ import { describe, expect, it, vi } from "vitest";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { screen } from "@testing-library/react";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import { SettingsForm } from "#/components/shared/modals/settings/settings-form";
|
||||
import { DEFAULT_SETTINGS } from "#/services/settings";
|
||||
|
||||
|
||||
@@ -1,26 +1,12 @@
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeAll,
|
||||
beforeEach,
|
||||
afterAll,
|
||||
afterEach,
|
||||
} from "vitest";
|
||||
import { describe, it, expect, beforeAll, afterAll, afterEach } from "vitest";
|
||||
import { screen, waitFor, render, cleanup } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { http, HttpResponse } from "msw";
|
||||
import { useOptimisticUserMessageStore } from "#/stores/optimistic-user-message-store";
|
||||
import { useBrowserStore } from "#/stores/browser-store";
|
||||
import { useCommandStore } from "#/state/command-store";
|
||||
import {
|
||||
createMockMessageEvent,
|
||||
createMockUserMessageEvent,
|
||||
createMockAgentErrorEvent,
|
||||
createMockBrowserObservationEvent,
|
||||
createMockBrowserNavigateActionEvent,
|
||||
createMockExecuteBashActionEvent,
|
||||
createMockExecuteBashObservationEvent,
|
||||
} from "#/mocks/mock-ws-helpers";
|
||||
import {
|
||||
ConnectionStatusComponent,
|
||||
@@ -475,7 +461,7 @@ describe("Conversation WebSocket Handler", () => {
|
||||
);
|
||||
|
||||
// Create a test component that displays loading state
|
||||
function HistoryLoadingComponent() {
|
||||
const HistoryLoadingComponent = () => {
|
||||
const context = useConversationWebSocket();
|
||||
const { events } = useEventStore();
|
||||
|
||||
@@ -488,7 +474,7 @@ describe("Conversation WebSocket Handler", () => {
|
||||
<div data-testid="expected-event-count">{expectedEventCount}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Render with WebSocket context
|
||||
renderWithWebSocketContext(
|
||||
@@ -498,9 +484,7 @@ describe("Conversation WebSocket Handler", () => {
|
||||
);
|
||||
|
||||
// Initially should be loading history
|
||||
expect(screen.getByTestId("is-loading-history")).toHaveTextContent(
|
||||
"true",
|
||||
);
|
||||
expect(screen.getByTestId("is-loading-history")).toHaveTextContent("true");
|
||||
|
||||
// Wait for all events to be received
|
||||
await waitFor(() => {
|
||||
@@ -539,7 +523,7 @@ describe("Conversation WebSocket Handler", () => {
|
||||
);
|
||||
|
||||
// Create a test component that displays loading state
|
||||
function HistoryLoadingComponent() {
|
||||
const HistoryLoadingComponent = () => {
|
||||
const context = useConversationWebSocket();
|
||||
|
||||
return (
|
||||
@@ -549,7 +533,7 @@ describe("Conversation WebSocket Handler", () => {
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Render with WebSocket context
|
||||
renderWithWebSocketContext(
|
||||
@@ -599,7 +583,7 @@ describe("Conversation WebSocket Handler", () => {
|
||||
);
|
||||
|
||||
// Create a test component that displays loading state
|
||||
function HistoryLoadingComponent() {
|
||||
const HistoryLoadingComponent = () => {
|
||||
const context = useConversationWebSocket();
|
||||
const { events } = useEventStore();
|
||||
|
||||
@@ -611,7 +595,7 @@ describe("Conversation WebSocket Handler", () => {
|
||||
<div data-testid="events-received">{events.length}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Render with WebSocket context
|
||||
renderWithWebSocketContext(
|
||||
@@ -621,9 +605,7 @@ describe("Conversation WebSocket Handler", () => {
|
||||
);
|
||||
|
||||
// Initially should be loading history
|
||||
expect(screen.getByTestId("is-loading-history")).toHaveTextContent(
|
||||
"true",
|
||||
);
|
||||
expect(screen.getByTestId("is-loading-history")).toHaveTextContent("true");
|
||||
|
||||
// Wait for all events to be received
|
||||
await waitFor(() => {
|
||||
@@ -639,133 +621,17 @@ describe("Conversation WebSocket Handler", () => {
|
||||
});
|
||||
});
|
||||
|
||||
// 9. Browser State Tests (BrowserObservation)
|
||||
describe("Browser State Integration", () => {
|
||||
beforeEach(() => {
|
||||
useBrowserStore.getState().reset();
|
||||
});
|
||||
|
||||
it("should update browser store with screenshot when BrowserObservation event is received", async () => {
|
||||
// Create a mock BrowserObservation event with screenshot data
|
||||
const mockBrowserObsEvent = createMockBrowserObservationEvent(
|
||||
"base64-screenshot-data",
|
||||
"Page loaded successfully",
|
||||
);
|
||||
|
||||
// Set up MSW to send the event when connection is established
|
||||
mswServer.use(
|
||||
wsLink.addEventListener("connection", ({ client, server }) => {
|
||||
server.connect();
|
||||
// Send the mock event after connection
|
||||
client.send(JSON.stringify(mockBrowserObsEvent));
|
||||
}),
|
||||
);
|
||||
|
||||
// Render with WebSocket context
|
||||
renderWithWebSocketContext(<ConnectionStatusComponent />);
|
||||
|
||||
// Wait for connection
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("connection-state")).toHaveTextContent(
|
||||
"OPEN",
|
||||
);
|
||||
});
|
||||
|
||||
// Wait for the browser store to be updated with screenshot
|
||||
await waitFor(() => {
|
||||
const { screenshotSrc } = useBrowserStore.getState();
|
||||
expect(screenshotSrc).toBe(
|
||||
"data:image/png;base64,base64-screenshot-data",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should update browser store with URL when BrowserNavigateAction followed by BrowserObservation", async () => {
|
||||
// Create mock events - action first, then observation
|
||||
const mockBrowserActionEvent = createMockBrowserNavigateActionEvent(
|
||||
"https://example.com/test-page",
|
||||
);
|
||||
const mockBrowserObsEvent = createMockBrowserObservationEvent(
|
||||
"base64-screenshot-data",
|
||||
"Page loaded successfully",
|
||||
);
|
||||
|
||||
// Set up MSW to send both events when connection is established
|
||||
mswServer.use(
|
||||
wsLink.addEventListener("connection", ({ client, server }) => {
|
||||
server.connect();
|
||||
// Send action first, then observation
|
||||
client.send(JSON.stringify(mockBrowserActionEvent));
|
||||
client.send(JSON.stringify(mockBrowserObsEvent));
|
||||
}),
|
||||
);
|
||||
|
||||
// Render with WebSocket context
|
||||
renderWithWebSocketContext(<ConnectionStatusComponent />);
|
||||
|
||||
// Wait for connection
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("connection-state")).toHaveTextContent(
|
||||
"OPEN",
|
||||
);
|
||||
});
|
||||
|
||||
// Wait for the browser store to be updated with both screenshot and URL
|
||||
await waitFor(() => {
|
||||
const { screenshotSrc, url } = useBrowserStore.getState();
|
||||
expect(screenshotSrc).toBe(
|
||||
"data:image/png;base64,base64-screenshot-data",
|
||||
);
|
||||
expect(url).toBe("https://example.com/test-page");
|
||||
});
|
||||
});
|
||||
|
||||
it("should not update browser store when BrowserObservation has no screenshot data", async () => {
|
||||
const initialScreenshot = useBrowserStore.getState().screenshotSrc;
|
||||
|
||||
// Create a mock BrowserObservation event WITHOUT screenshot data
|
||||
const mockBrowserObsEvent = createMockBrowserObservationEvent(
|
||||
null, // no screenshot
|
||||
"Browser action completed",
|
||||
);
|
||||
|
||||
// Set up MSW to send the event when connection is established
|
||||
mswServer.use(
|
||||
wsLink.addEventListener("connection", ({ client, server }) => {
|
||||
server.connect();
|
||||
// Send the mock event after connection
|
||||
client.send(JSON.stringify(mockBrowserObsEvent));
|
||||
}),
|
||||
);
|
||||
|
||||
// Render with WebSocket context
|
||||
renderWithWebSocketContext(<ConnectionStatusComponent />);
|
||||
|
||||
// Wait for connection
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("connection-state")).toHaveTextContent(
|
||||
"OPEN",
|
||||
);
|
||||
});
|
||||
|
||||
// Give some time for any potential updates
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, 100);
|
||||
});
|
||||
|
||||
// Screenshot should remain unchanged (empty/initial value)
|
||||
const { screenshotSrc } = useBrowserStore.getState();
|
||||
expect(screenshotSrc).toBe(initialScreenshot);
|
||||
});
|
||||
});
|
||||
|
||||
// 10. Terminal I/O Tests (ExecuteBashAction and ExecuteBashObservation)
|
||||
// 9. Terminal I/O Tests (ExecuteBashAction and ExecuteBashObservation)
|
||||
describe("Terminal I/O Integration", () => {
|
||||
beforeEach(() => {
|
||||
useCommandStore.getState().clearTerminal();
|
||||
});
|
||||
|
||||
it("should append command to store when ExecuteBashAction event is received", async () => {
|
||||
const { createMockExecuteBashActionEvent } = await import(
|
||||
"#/mocks/mock-ws-helpers"
|
||||
);
|
||||
const { useCommandStore } = await import("#/state/command-store");
|
||||
|
||||
// Clear the command store before test
|
||||
useCommandStore.getState().clearTerminal();
|
||||
|
||||
// Create a mock ExecuteBashAction event
|
||||
const mockBashActionEvent = createMockExecuteBashActionEvent("npm test");
|
||||
|
||||
@@ -801,6 +667,14 @@ describe("Conversation WebSocket Handler", () => {
|
||||
});
|
||||
|
||||
it("should append output to store when ExecuteBashObservation event is received", async () => {
|
||||
const { createMockExecuteBashObservationEvent } = await import(
|
||||
"#/mocks/mock-ws-helpers"
|
||||
);
|
||||
const { useCommandStore } = await import("#/state/command-store");
|
||||
|
||||
// Clear the command store before test
|
||||
useCommandStore.getState().clearTerminal();
|
||||
|
||||
// Create a mock ExecuteBashObservation event
|
||||
const mockBashObservationEvent = createMockExecuteBashObservationEvent(
|
||||
"PASS tests/example.test.js\n ✓ should work (2 ms)",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import { useSaveSettings } from "#/hooks/mutation/use-save-settings";
|
||||
|
||||
describe("useSaveSettings", () => {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
/* eslint-disable max-classes-per-file */
|
||||
import { beforeAll, describe, expect, it, vi, afterEach } from "vitest";
|
||||
import { useTerminal } from "#/hooks/use-terminal";
|
||||
import { Command, useCommandStore } from "#/state/command-store";
|
||||
@@ -46,29 +45,17 @@ describe("useTerminal", () => {
|
||||
}));
|
||||
|
||||
beforeAll(() => {
|
||||
// mock ResizeObserver - use class for Vitest 4 constructor support
|
||||
window.ResizeObserver = class {
|
||||
observe = vi.fn();
|
||||
// mock ResizeObserver
|
||||
window.ResizeObserver = vi.fn().mockImplementation(() => ({
|
||||
observe: vi.fn(),
|
||||
unobserve: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
}));
|
||||
|
||||
unobserve = vi.fn();
|
||||
|
||||
disconnect = vi.fn();
|
||||
} as unknown as typeof ResizeObserver;
|
||||
|
||||
// mock Terminal - use class for Vitest 4 constructor support
|
||||
// mock Terminal
|
||||
vi.mock("@xterm/xterm", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("@xterm/xterm")>()),
|
||||
Terminal: class {
|
||||
loadAddon = mockTerminal.loadAddon;
|
||||
|
||||
open = mockTerminal.open;
|
||||
|
||||
write = mockTerminal.write;
|
||||
|
||||
writeln = mockTerminal.writeln;
|
||||
|
||||
dispose = mockTerminal.dispose;
|
||||
},
|
||||
Terminal: vi.fn().mockImplementation(() => mockTerminal),
|
||||
}));
|
||||
});
|
||||
|
||||
|
||||
@@ -1,11 +1,3 @@
|
||||
/**
|
||||
* TODO: Fix flaky WebSocket tests (https://github.com/OpenHands/OpenHands/issues/11944)
|
||||
*
|
||||
* Several tests in this file are skipped because they fail intermittently in CI
|
||||
* but pass locally. The SUSPECTED root cause is that `wsLink.broadcast()` sends messages
|
||||
* to ALL connected clients across all tests, causing cross-test contamination
|
||||
* when tests run in parallel with Vitest v4.
|
||||
*/
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import {
|
||||
describe,
|
||||
@@ -59,7 +51,7 @@ describe("useWebSocket", () => {
|
||||
expect(result.current.socket).toBeTruthy();
|
||||
});
|
||||
|
||||
it.skip("should handle incoming messages correctly", async () => {
|
||||
it("should handle incoming messages correctly", async () => {
|
||||
const { result } = renderHook(() => useWebSocket("ws://acme.com/ws"));
|
||||
|
||||
// Wait for connection to be established
|
||||
@@ -122,7 +114,7 @@ describe("useWebSocket", () => {
|
||||
expect(result.current.socket).toBeTruthy();
|
||||
});
|
||||
|
||||
it.skip("should close the WebSocket connection on unmount", async () => {
|
||||
it("should close the WebSocket connection on unmount", async () => {
|
||||
const { result, unmount } = renderHook(() =>
|
||||
useWebSocket("ws://acme.com/ws"),
|
||||
);
|
||||
@@ -212,7 +204,7 @@ describe("useWebSocket", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it.skip("should call onMessage handler when WebSocket receives a message", async () => {
|
||||
it("should call onMessage handler when WebSocket receives a message", async () => {
|
||||
const onMessageSpy = vi.fn();
|
||||
const options = { onMessage: onMessageSpy };
|
||||
|
||||
@@ -279,7 +271,7 @@ describe("useWebSocket", () => {
|
||||
expect(onErrorSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it.skip("should provide sendMessage function to send messages to WebSocket", async () => {
|
||||
it("should provide sendMessage function to send messages to WebSocket", async () => {
|
||||
const { result } = renderHook(() => useWebSocket("ws://acme.com/ws"));
|
||||
|
||||
// Wait for connection to be established
|
||||
|
||||
@@ -10,7 +10,7 @@ import MainApp from "#/routes/root-layout";
|
||||
import i18n from "#/i18n";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import * as CaptureConsent from "#/utils/handle-capture-consent";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import * as ToastHandlers from "#/utils/custom-toast-handlers";
|
||||
|
||||
describe("frontend/routes/_oh", () => {
|
||||
|
||||
@@ -3,7 +3,7 @@ import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import AppSettingsScreen from "#/routes/app-settings";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { AvailableLanguages } from "#/i18n";
|
||||
import * as CaptureConsent from "#/utils/handle-capture-consent";
|
||||
|
||||
@@ -6,7 +6,7 @@ import userEvent from "@testing-library/user-event";
|
||||
import i18next from "i18next";
|
||||
import { I18nextProvider } from "react-i18next";
|
||||
import GitSettingsScreen from "#/routes/git-settings";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import AuthService from "#/api/auth-service/auth-service.api";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
|
||||
@@ -6,7 +6,7 @@ import { createRoutesStub } from "react-router";
|
||||
import { createAxiosNotFoundErrorObject } from "test-utils";
|
||||
import HomeScreen from "#/routes/home";
|
||||
import { GitRepository } from "#/types/git";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import GitService from "#/api/git-service/git-service.api";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import MainApp from "#/routes/root-layout";
|
||||
|
||||
@@ -3,14 +3,13 @@ import userEvent from "@testing-library/user-event";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { QueryClientProvider, QueryClient } from "@tanstack/react-query";
|
||||
import LlmSettingsScreen from "#/routes/llm-settings";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import {
|
||||
MOCK_DEFAULT_USER_SETTINGS,
|
||||
resetTestHandlersMockSettings,
|
||||
} from "#/mocks/handlers";
|
||||
import * as AdvancedSettingsUtlls from "#/utils/has-advanced-settings-set";
|
||||
import * as ToastHandlers from "#/utils/custom-toast-handlers";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
|
||||
// Mock react-router hooks
|
||||
const mockUseSearchParams = vi.fn();
|
||||
@@ -256,210 +255,6 @@ describe("Content", () => {
|
||||
});
|
||||
|
||||
it.todo("should render an indicator if the llm api key is set");
|
||||
|
||||
describe("API key visibility in Basic Settings", () => {
|
||||
it("should hide API key input when SaaS mode is enabled and OpenHands provider is selected", async () => {
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return APP_MODE for these tests
|
||||
getConfigSpy.mockResolvedValue({
|
||||
APP_MODE: "saas",
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
const basicForm = screen.getByTestId("llm-settings-form-basic");
|
||||
const provider = within(basicForm).getByTestId("llm-provider-input");
|
||||
|
||||
// Verify OpenHands is selected by default
|
||||
await waitFor(() => {
|
||||
expect(provider).toHaveValue("OpenHands");
|
||||
});
|
||||
|
||||
// API key input should not be visible when OpenHands provider is selected in SaaS mode
|
||||
expect(
|
||||
within(basicForm).queryByTestId("llm-api-key-input"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(
|
||||
within(basicForm).queryByTestId("llm-api-key-help-anchor"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show API key input when SaaS mode is enabled and non-OpenHands provider is selected", async () => {
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return APP_MODE for these tests
|
||||
getConfigSpy.mockResolvedValue({
|
||||
APP_MODE: "saas",
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
const basicForm = screen.getByTestId("llm-settings-form-basic");
|
||||
const provider = within(basicForm).getByTestId("llm-provider-input");
|
||||
|
||||
// Select OpenAI provider
|
||||
await userEvent.click(provider);
|
||||
const providerOption = screen.getByText("OpenAI");
|
||||
await userEvent.click(providerOption);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(provider).toHaveValue("OpenAI");
|
||||
});
|
||||
|
||||
// API key input should be visible when non-OpenHands provider is selected in SaaS mode
|
||||
expect(
|
||||
within(basicForm).getByTestId("llm-api-key-input"),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
within(basicForm).getByTestId("llm-api-key-help-anchor"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show API key input when OSS mode is enabled and OpenHands provider is selected", async () => {
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return APP_MODE for these tests
|
||||
getConfigSpy.mockResolvedValue({
|
||||
APP_MODE: "oss",
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
const basicForm = screen.getByTestId("llm-settings-form-basic");
|
||||
const provider = within(basicForm).getByTestId("llm-provider-input");
|
||||
|
||||
// Verify OpenHands is selected by default
|
||||
await waitFor(() => {
|
||||
expect(provider).toHaveValue("OpenHands");
|
||||
});
|
||||
|
||||
// API key input should be visible when OSS mode is enabled (even with OpenHands provider)
|
||||
expect(
|
||||
within(basicForm).getByTestId("llm-api-key-input"),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
within(basicForm).getByTestId("llm-api-key-help-anchor"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show API key input when OSS mode is enabled and non-OpenHands provider is selected", async () => {
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return APP_MODE for these tests
|
||||
getConfigSpy.mockResolvedValue({
|
||||
APP_MODE: "oss",
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
const basicForm = screen.getByTestId("llm-settings-form-basic");
|
||||
const provider = within(basicForm).getByTestId("llm-provider-input");
|
||||
|
||||
// Select OpenAI provider
|
||||
await userEvent.click(provider);
|
||||
const providerOption = screen.getByText("OpenAI");
|
||||
await userEvent.click(providerOption);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(provider).toHaveValue("OpenAI");
|
||||
});
|
||||
|
||||
// API key input should be visible when OSS mode is enabled
|
||||
expect(
|
||||
within(basicForm).getByTestId("llm-api-key-input"),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
within(basicForm).getByTestId("llm-api-key-help-anchor"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should hide API key input when switching from non-OpenHands to OpenHands provider in SaaS mode", async () => {
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return APP_MODE for these tests
|
||||
getConfigSpy.mockResolvedValue({
|
||||
APP_MODE: "saas",
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
const basicForm = screen.getByTestId("llm-settings-form-basic");
|
||||
const provider = within(basicForm).getByTestId("llm-provider-input");
|
||||
|
||||
// Start with OpenAI provider
|
||||
await userEvent.click(provider);
|
||||
const openAIOption = screen.getByText("OpenAI");
|
||||
await userEvent.click(openAIOption);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(provider).toHaveValue("OpenAI");
|
||||
});
|
||||
|
||||
// API key input should be visible with OpenAI
|
||||
expect(
|
||||
within(basicForm).getByTestId("llm-api-key-input"),
|
||||
).toBeInTheDocument();
|
||||
|
||||
// Switch to OpenHands provider
|
||||
await userEvent.click(provider);
|
||||
const openHandsOption = screen.getByText("OpenHands");
|
||||
await userEvent.click(openHandsOption);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(provider).toHaveValue("OpenHands");
|
||||
});
|
||||
|
||||
// API key input should now be hidden
|
||||
expect(
|
||||
within(basicForm).queryByTestId("llm-api-key-input"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(
|
||||
within(basicForm).queryByTestId("llm-api-key-help-anchor"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show API key input when switching from OpenHands to non-OpenHands provider in SaaS mode", async () => {
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return APP_MODE for these tests
|
||||
getConfigSpy.mockResolvedValue({
|
||||
APP_MODE: "saas",
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
const basicForm = screen.getByTestId("llm-settings-form-basic");
|
||||
const provider = within(basicForm).getByTestId("llm-provider-input");
|
||||
|
||||
// Verify OpenHands is selected by default
|
||||
await waitFor(() => {
|
||||
expect(provider).toHaveValue("OpenHands");
|
||||
});
|
||||
|
||||
// API key input should be hidden with OpenHands
|
||||
expect(
|
||||
within(basicForm).queryByTestId("llm-api-key-input"),
|
||||
).not.toBeInTheDocument();
|
||||
|
||||
// Switch to OpenAI provider
|
||||
await userEvent.click(provider);
|
||||
const openAIOption = screen.getByText("OpenAI");
|
||||
await userEvent.click(openAIOption);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(provider).toHaveValue("OpenAI");
|
||||
});
|
||||
|
||||
// API key input should now be visible
|
||||
expect(
|
||||
within(basicForm).getByTestId("llm-api-key-input"),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
within(basicForm).getByTestId("llm-api-key-help-anchor"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Form submission", () => {
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import { render, screen, waitFor, within } from "@testing-library/react";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { createRoutesStub, Outlet } from "react-router";
|
||||
import SecretsSettingsScreen from "#/routes/secrets-settings";
|
||||
import { SecretsService } from "#/api/secrets-service";
|
||||
import { GetSecretsResponse } from "#/api/secrets-service.types";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
|
||||
@@ -21,25 +21,25 @@ const MOCK_GET_SECRETS_RESPONSE: GetSecretsResponse["custom_secrets"] = [
|
||||
},
|
||||
];
|
||||
|
||||
const renderSecretsSettings = () => {
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: () => <Outlet />,
|
||||
path: "/settings",
|
||||
children: [
|
||||
{
|
||||
Component: SecretsSettingsScreen,
|
||||
path: "/settings/secrets",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="git-settings-screen" />,
|
||||
path: "/settings/integrations",
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: () => <Outlet />,
|
||||
path: "/settings",
|
||||
children: [
|
||||
{
|
||||
Component: SecretsSettingsScreen,
|
||||
path: "/settings/secrets",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="git-settings-screen" />,
|
||||
path: "/settings/integrations",
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
|
||||
return render(<RouterStub initialEntries={["/settings/secrets"]} />, {
|
||||
const renderSecretsSettings = () =>
|
||||
render(<RouterStub initialEntries={["/settings/secrets"]} />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider
|
||||
client={
|
||||
@@ -52,7 +52,6 @@ const renderSecretsSettings = () => {
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
@@ -62,10 +61,6 @@ beforeEach(() => {
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe("Content", () => {
|
||||
it("should render the secrets settings screen", () => {
|
||||
renderSecretsSettings();
|
||||
@@ -506,8 +501,6 @@ describe("Secret actions", () => {
|
||||
|
||||
it("should not submit whitespace secret names or values", async () => {
|
||||
const createSecretSpy = vi.spyOn(SecretsService, "createSecret");
|
||||
const getSecretsSpy = vi.spyOn(SecretsService, "getSecrets");
|
||||
getSecretsSpy.mockResolvedValue([]);
|
||||
renderSecretsSettings();
|
||||
|
||||
// render form & hide items
|
||||
@@ -539,11 +532,9 @@ describe("Secret actions", () => {
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
expect(createSecretSpy).not.toHaveBeenCalled();
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByText("SECRETS$SECRET_VALUE_REQUIRED"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
expect(
|
||||
screen.queryByText("SECRETS$SECRET_VALUE_REQUIRED"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not reset ipout values on an invalid submit", async () => {
|
||||
|
||||
Generated
+2690
-1606
File diff suppressed because it is too large
Load Diff
+42
-42
@@ -8,56 +8,56 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@heroui/react": "2.8.5",
|
||||
"@heroui/use-infinite-scroll": "^2.2.12",
|
||||
"@heroui/use-infinite-scroll": "^2.2.11",
|
||||
"@microlink/react-json-view": "^1.26.2",
|
||||
"@monaco-editor/react": "^4.7.0-rc.0",
|
||||
"@posthog/react": "^1.5.2",
|
||||
"@react-router/node": "^7.10.1",
|
||||
"@react-router/serve": "^7.10.1",
|
||||
"@posthog/react": "^1.4.0",
|
||||
"@react-router/node": "^7.9.3",
|
||||
"@react-router/serve": "^7.9.3",
|
||||
"@react-types/shared": "^3.32.0",
|
||||
"@stripe/react-stripe-js": "^5.4.1",
|
||||
"@stripe/stripe-js": "^8.5.3",
|
||||
"@tailwindcss/postcss": "^4.1.17",
|
||||
"@tailwindcss/vite": "^4.1.17",
|
||||
"@tanstack/react-query": "^5.90.12",
|
||||
"@stripe/react-stripe-js": "^4.0.2",
|
||||
"@stripe/stripe-js": "^7.9.0",
|
||||
"@tailwindcss/postcss": "^4.1.13",
|
||||
"@tailwindcss/vite": "^4.1.13",
|
||||
"@tanstack/react-query": "^5.90.2",
|
||||
"@uidotdev/usehooks": "^2.4.1",
|
||||
"@vitejs/plugin-react": "^5.1.2",
|
||||
"@vitejs/plugin-react": "^5.0.4",
|
||||
"@xterm/addon-fit": "^0.10.0",
|
||||
"@xterm/xterm": "^5.4.0",
|
||||
"axios": "^1.13.2",
|
||||
"axios": "^1.12.2",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"date-fns": "^4.1.0",
|
||||
"downshift": "^9.0.12",
|
||||
"downshift": "^9.0.10",
|
||||
"eslint-config-airbnb-typescript": "^18.0.0",
|
||||
"framer-motion": "^12.23.25",
|
||||
"i18next": "^25.7.2",
|
||||
"framer-motion": "^12.23.22",
|
||||
"i18next": "^25.5.2",
|
||||
"i18next-browser-languagedetector": "^8.2.0",
|
||||
"i18next-http-backend": "^3.0.2",
|
||||
"isbot": "^5.1.32",
|
||||
"jose": "^6.1.3",
|
||||
"lucide-react": "^0.556.0",
|
||||
"monaco-editor": "^0.55.1",
|
||||
"posthog-js": "^1.302.2",
|
||||
"react": "^19.2.0",
|
||||
"react-dom": "^19.2.0",
|
||||
"isbot": "^5.1.31",
|
||||
"jose": "^6.1.0",
|
||||
"lucide-react": "^0.544.0",
|
||||
"monaco-editor": "^0.53.0",
|
||||
"posthog-js": "^1.298.1",
|
||||
"react": "^19.1.1",
|
||||
"react-dom": "^19.1.1",
|
||||
"react-highlight": "^0.15.0",
|
||||
"react-hot-toast": "^2.6.0",
|
||||
"react-i18next": "^16.4.0",
|
||||
"react-i18next": "^16.0.0",
|
||||
"react-icons": "^5.5.0",
|
||||
"react-markdown": "^10.1.0",
|
||||
"react-router": "^7.10.1",
|
||||
"react-syntax-highlighter": "^16.1.0",
|
||||
"react-router": "^7.9.3",
|
||||
"react-syntax-highlighter": "^15.6.6",
|
||||
"remark-breaks": "^4.0.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"sirv-cli": "^3.0.1",
|
||||
"socket.io-client": "^4.8.1",
|
||||
"tailwind-merge": "^3.4.0",
|
||||
"tailwind-merge": "^3.3.1",
|
||||
"tailwind-scrollbar": "^4.0.2",
|
||||
"vite": "^7.2.7",
|
||||
"vite": "^7.1.7",
|
||||
"web-vitals": "^5.1.0",
|
||||
"ws": "^8.18.2",
|
||||
"zustand": "^5.0.9"
|
||||
"zustand": "^5.0.8"
|
||||
},
|
||||
"scripts": {
|
||||
"dev": "npm run make-i18n && cross-env VITE_MOCK_API=false react-router dev",
|
||||
@@ -96,25 +96,25 @@
|
||||
"@babel/traverse": "^7.28.3",
|
||||
"@babel/types": "^7.28.2",
|
||||
"@mswjs/socket.io-binding": "^0.2.0",
|
||||
"@playwright/test": "^1.57.0",
|
||||
"@react-router/dev": "^7.10.1",
|
||||
"@playwright/test": "^1.55.1",
|
||||
"@react-router/dev": "^7.9.3",
|
||||
"@tailwindcss/typography": "^0.5.19",
|
||||
"@tanstack/eslint-plugin-query": "^5.91.0",
|
||||
"@testing-library/dom": "^10.4.1",
|
||||
"@testing-library/jest-dom": "^6.9.1",
|
||||
"@testing-library/jest-dom": "^6.8.0",
|
||||
"@testing-library/react": "^16.3.0",
|
||||
"@testing-library/user-event": "^14.6.1",
|
||||
"@types/node": "^24.10.1",
|
||||
"@types/react": "^19.2.7",
|
||||
"@types/react-dom": "^19.2.3",
|
||||
"@types/node": "^24.5.2",
|
||||
"@types/react": "^19.1.15",
|
||||
"@types/react-dom": "^19.1.9",
|
||||
"@types/react-highlight": "^0.12.8",
|
||||
"@types/react-syntax-highlighter": "^15.5.13",
|
||||
"@types/ws": "^8.18.1",
|
||||
"@typescript-eslint/eslint-plugin": "^7.18.0",
|
||||
"@typescript-eslint/parser": "^7.18.0",
|
||||
"@vitest/coverage-v8": "^4.0.14",
|
||||
"autoprefixer": "^10.4.22",
|
||||
"cross-env": "^10.1.0",
|
||||
"@vitest/coverage-v8": "^3.2.3",
|
||||
"autoprefixer": "^10.4.21",
|
||||
"cross-env": "^10.0.0",
|
||||
"eslint": "^8.57.0",
|
||||
"eslint-config-airbnb": "^19.0.4",
|
||||
"eslint-config-airbnb-typescript": "^18.0.0",
|
||||
@@ -127,16 +127,16 @@
|
||||
"eslint-plugin-react-hooks": "^4.6.2",
|
||||
"eslint-plugin-unused-imports": "^4.2.0",
|
||||
"husky": "^9.1.7",
|
||||
"jsdom": "^27.3.0",
|
||||
"lint-staged": "^16.2.7",
|
||||
"jsdom": "^27.0.0",
|
||||
"lint-staged": "^16.2.3",
|
||||
"msw": "^2.6.6",
|
||||
"prettier": "^3.7.3",
|
||||
"stripe": "^20.0.0",
|
||||
"prettier": "^3.6.2",
|
||||
"stripe": "^18.5.0",
|
||||
"tailwindcss": "^4.1.8",
|
||||
"typescript": "^5.9.3",
|
||||
"typescript": "^5.9.2",
|
||||
"vite-plugin-svgr": "^4.5.0",
|
||||
"vite-tsconfig-paths": "^5.1.4",
|
||||
"vitest": "^4.0.14"
|
||||
"vitest": "^3.0.2"
|
||||
},
|
||||
"packageManager": "npm@10.5.0",
|
||||
"volta": {
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
* - Please do NOT modify this file.
|
||||
*/
|
||||
|
||||
const PACKAGE_VERSION = '2.12.3'
|
||||
const INTEGRITY_CHECKSUM = '4db4a41e972cec1b64cc569c66952d82'
|
||||
const PACKAGE_VERSION = '2.11.1'
|
||||
const INTEGRITY_CHECKSUM = 'f5825c521429caf22a4dd13b66e243af'
|
||||
const IS_MOCKED_RESPONSE = Symbol('isMockedResponse')
|
||||
const activeClientIds = new Set()
|
||||
|
||||
@@ -71,6 +71,11 @@ addEventListener('message', async function (event) {
|
||||
break
|
||||
}
|
||||
|
||||
case 'MOCK_DEACTIVATE': {
|
||||
activeClientIds.delete(clientId)
|
||||
break
|
||||
}
|
||||
|
||||
case 'CLIENT_CLOSED': {
|
||||
activeClientIds.delete(clientId)
|
||||
|
||||
@@ -89,8 +94,6 @@ addEventListener('message', async function (event) {
|
||||
})
|
||||
|
||||
addEventListener('fetch', function (event) {
|
||||
const requestInterceptedAt = Date.now()
|
||||
|
||||
// Bypass navigation requests.
|
||||
if (event.request.mode === 'navigate') {
|
||||
return
|
||||
@@ -107,29 +110,23 @@ addEventListener('fetch', function (event) {
|
||||
|
||||
// Bypass all requests when there are no active clients.
|
||||
// Prevents the self-unregistered worked from handling requests
|
||||
// after it's been terminated (still remains active until the next reload).
|
||||
// after it's been deleted (still remains active until the next reload).
|
||||
if (activeClientIds.size === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
const requestId = crypto.randomUUID()
|
||||
event.respondWith(handleRequest(event, requestId, requestInterceptedAt))
|
||||
event.respondWith(handleRequest(event, requestId))
|
||||
})
|
||||
|
||||
/**
|
||||
* @param {FetchEvent} event
|
||||
* @param {string} requestId
|
||||
* @param {number} requestInterceptedAt
|
||||
*/
|
||||
async function handleRequest(event, requestId, requestInterceptedAt) {
|
||||
async function handleRequest(event, requestId) {
|
||||
const client = await resolveMainClient(event)
|
||||
const requestCloneForEvents = event.request.clone()
|
||||
const response = await getResponse(
|
||||
event,
|
||||
client,
|
||||
requestId,
|
||||
requestInterceptedAt,
|
||||
)
|
||||
const response = await getResponse(event, client, requestId)
|
||||
|
||||
// Send back the response clone for the "response:*" life-cycle events.
|
||||
// Ensure MSW is active and ready to handle the message, otherwise
|
||||
@@ -205,10 +202,9 @@ async function resolveMainClient(event) {
|
||||
* @param {FetchEvent} event
|
||||
* @param {Client | undefined} client
|
||||
* @param {string} requestId
|
||||
* @param {number} requestInterceptedAt
|
||||
* @returns {Promise<Response>}
|
||||
*/
|
||||
async function getResponse(event, client, requestId, requestInterceptedAt) {
|
||||
async function getResponse(event, client, requestId) {
|
||||
// Clone the request because it might've been already used
|
||||
// (i.e. its body has been read and sent to the client).
|
||||
const requestClone = event.request.clone()
|
||||
@@ -259,7 +255,6 @@ async function getResponse(event, client, requestId, requestInterceptedAt) {
|
||||
type: 'REQUEST',
|
||||
payload: {
|
||||
id: requestId,
|
||||
interceptedAt: requestInterceptedAt,
|
||||
...serializedRequest,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -22,13 +22,6 @@ const getCommandObservationContent = (
|
||||
if (content.length > MAX_CONTENT_LENGTH) {
|
||||
content = `${content.slice(0, MAX_CONTENT_LENGTH)}...`;
|
||||
}
|
||||
|
||||
const command = event.observation === "run" ? event.extras.command : null;
|
||||
|
||||
if (command) {
|
||||
return `Command:\n\`\`\`sh\n${command}\n\`\`\`\n\nOutput:\n\`\`\`sh\n${content.trim() || i18n.t("OBSERVATION$COMMAND_NO_OUTPUT")}\n\`\`\``;
|
||||
}
|
||||
|
||||
return `Output:\n\`\`\`sh\n${content.trim() || i18n.t("OBSERVATION$COMMAND_NO_OUTPUT")}\n\`\`\``;
|
||||
};
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import { I18nKey } from "#/i18n/declaration";
|
||||
import ArrowDown from "#/icons/angle-down-solid.svg?react";
|
||||
import ArrowUp from "#/icons/angle-up-solid.svg?react";
|
||||
import CheckCircle from "#/icons/check-circle-solid.svg?react";
|
||||
import XCircle from "#/icons/x-circle-solid.svg?react";
|
||||
import { OpenHandsAction } from "#/types/core/actions";
|
||||
import { OpenHandsObservation } from "#/types/core/observations";
|
||||
import { cn } from "#/utils/utils";
|
||||
@@ -94,7 +95,7 @@ export function ExpandableMessage({
|
||||
const statusIconClasses = "h-4 w-4 ml-2 inline";
|
||||
|
||||
if (
|
||||
config?.FEATURE_FLAGS?.ENABLE_BILLING &&
|
||||
config?.FEATURE_FLAGS.ENABLE_BILLING &&
|
||||
config?.APP_MODE === "saas" &&
|
||||
id === I18nKey.STATUS$ERROR_LLM_OUT_OF_CREDITS
|
||||
) {
|
||||
@@ -168,12 +169,19 @@ export function ExpandableMessage({
|
||||
)}
|
||||
</button>
|
||||
</span>
|
||||
{type === "action" && success && (
|
||||
{type === "action" && success !== undefined && (
|
||||
<span className="flex-shrink-0">
|
||||
<CheckCircle
|
||||
data-testid="status-icon"
|
||||
className={cn(statusIconClasses, "fill-success")}
|
||||
/>
|
||||
{success ? (
|
||||
<CheckCircle
|
||||
data-testid="status-icon"
|
||||
className={cn(statusIconClasses, "fill-success")}
|
||||
/>
|
||||
) : (
|
||||
<XCircle
|
||||
data-testid="status-icon"
|
||||
className={cn(statusIconClasses, "fill-danger")}
|
||||
/>
|
||||
)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { FaClock } from "react-icons/fa";
|
||||
import CheckCircle from "#/icons/check-circle-solid.svg?react";
|
||||
import XCircle from "#/icons/x-circle-solid.svg?react";
|
||||
import { ObservationResultStatus } from "./event-content-helpers/get-observation-result";
|
||||
|
||||
interface SuccessIndicatorProps {
|
||||
@@ -16,6 +17,13 @@ export function SuccessIndicator({ status }: SuccessIndicatorProps) {
|
||||
/>
|
||||
)}
|
||||
|
||||
{status === "error" && (
|
||||
<XCircle
|
||||
data-testid="status-icon"
|
||||
className="h-4 w-4 ml-2 inline fill-danger"
|
||||
/>
|
||||
)}
|
||||
|
||||
{status === "timeout" && (
|
||||
<FaClock
|
||||
data-testid="status-icon"
|
||||
|
||||
@@ -25,14 +25,7 @@ export function AccountSettingsContextMenu({
|
||||
const { data: config } = useConfig();
|
||||
|
||||
const isSaas = config?.APP_MODE === "saas";
|
||||
|
||||
// Get navigation items and filter out LLM settings if the feature flag is enabled
|
||||
let items = isSaas ? SAAS_NAV_ITEMS : OSS_NAV_ITEMS;
|
||||
if (config?.FEATURE_FLAGS?.HIDE_LLM_SETTINGS) {
|
||||
items = items.filter((item) => item.to !== "/settings");
|
||||
}
|
||||
|
||||
const navItems = items.map((item) => ({
|
||||
const navItems = (isSaas ? SAAS_NAV_ITEMS : OSS_NAV_ITEMS).map((item) => ({
|
||||
...item,
|
||||
icon: React.cloneElement(item.icon, {
|
||||
width: 16,
|
||||
|
||||
+4
-9
@@ -19,10 +19,8 @@ import {
|
||||
} from "#/state/conversation-store";
|
||||
import { ConversationTabsContextMenu } from "./conversation-tabs-context-menu";
|
||||
import { USE_PLANNING_AGENT } from "#/utils/feature-flags";
|
||||
import { useConversationId } from "#/hooks/use-conversation-id";
|
||||
|
||||
export function ConversationTabs() {
|
||||
const { conversationId } = useConversationId();
|
||||
const {
|
||||
selectedTab,
|
||||
isRightPanelShown,
|
||||
@@ -32,21 +30,18 @@ export function ConversationTabs() {
|
||||
|
||||
const [isMenuOpen, setIsMenuOpen] = useState(false);
|
||||
|
||||
// Persist selectedTab and isRightPanelShown in localStorage per conversation
|
||||
// Persist selectedTab and isRightPanelShown in localStorage
|
||||
const [persistedSelectedTab, setPersistedSelectedTab] =
|
||||
useLocalStorage<ConversationTab | null>(
|
||||
`conversation-selected-tab-${conversationId}`,
|
||||
"conversation-selected-tab",
|
||||
"editor",
|
||||
);
|
||||
|
||||
const [persistedIsRightPanelShown, setPersistedIsRightPanelShown] =
|
||||
useLocalStorage<boolean>(
|
||||
`conversation-right-panel-shown-${conversationId}`,
|
||||
true,
|
||||
);
|
||||
useLocalStorage<boolean>("conversation-right-panel-shown", true);
|
||||
|
||||
const [persistedUnpinnedTabs] = useLocalStorage<string[]>(
|
||||
`conversation-unpinned-tabs-${conversationId}`,
|
||||
"conversation-unpinned-tabs",
|
||||
[],
|
||||
);
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user