mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
303 Commits
azure-devo
...
attempt-ci
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7c2029a0e | ||
|
|
997371aed7 | ||
|
|
a1cb0d75af | ||
|
|
0c7b4573c9 | ||
|
|
64e4ef1b15 | ||
|
|
b34c89c0f8 | ||
|
|
d5734a8d0c | ||
|
|
e760c182dc | ||
|
|
71009298af | ||
|
|
48f08cab0e | ||
|
|
475e96c314 | ||
|
|
c5dda5d0d7 | ||
|
|
63086831cb | ||
|
|
0d163bf1ce | ||
|
|
6649be08a7 | ||
|
|
ee0c1a1c2f | ||
|
|
d3c002aee5 | ||
|
|
a2d61e0eb6 | ||
|
|
fab75ab33d | ||
|
|
a8c4fc5318 | ||
|
|
1647a2466f | ||
|
|
fe5b4bb34c | ||
|
|
fb0bfd3684 | ||
|
|
9f5c2327ec | ||
|
|
1864cf9b7a | ||
|
|
9ecf2c7e85 | ||
|
|
0aaad16d35 | ||
|
|
df92923959 | ||
|
|
e18168020a | ||
|
|
a9c76d0ed4 | ||
|
|
3743d10766 | ||
|
|
e8d89d9a55 | ||
|
|
d33c405ed5 | ||
|
|
3db4d3210d | ||
|
|
5ca5bbf3f0 | ||
|
|
b97a4fdee9 | ||
|
|
36a135b942 | ||
|
|
7dff779fce | ||
|
|
f40954f39e | ||
|
|
9b57a0b14f | ||
|
|
00797cd8a1 | ||
|
|
be5cd4c818 | ||
|
|
297140e727 | ||
|
|
8559efa7b2 | ||
|
|
bf06b7e3f3 | ||
|
|
959d610d86 | ||
|
|
16125f2ae9 | ||
|
|
d31950c061 | ||
|
|
db64abc580 | ||
|
|
ed7adb335c | ||
|
|
584517edec | ||
|
|
1a983d2978 | ||
|
|
d7b36c9579 | ||
|
|
72c7d9c497 | ||
|
|
7811a62491 | ||
|
|
4344f5ad4e | ||
|
|
17821f782e | ||
|
|
e1b283886f | ||
|
|
7e5942c2c1 | ||
|
|
1d9cf72e39 | ||
|
|
1d3ed8f6fa | ||
|
|
1aec00e92a | ||
|
|
59ca8bd9a8 | ||
|
|
3a9aa90c3a | ||
|
|
0a98f165e2 | ||
|
|
6ec477dae2 | ||
|
|
517a8c3d9b | ||
|
|
036ef85e9d | ||
|
|
44ef2012df | ||
|
|
cd765937f5 | ||
|
|
d0496fea8c | ||
|
|
8f91db8ec4 | ||
|
|
816d8acf1f | ||
|
|
97e6cb1340 | ||
|
|
cd9a3b02cf | ||
|
|
14695a8f0e | ||
|
|
eaea8b3ce1 | ||
|
|
72555e0f1c | ||
|
|
fd13c91387 | ||
|
|
6139e39449 | ||
|
|
f76ac242f0 | ||
|
|
1f9350320f | ||
|
|
1a3460ba06 | ||
|
|
8f361b3698 | ||
|
|
fd6e0cab3f | ||
|
|
33eec7cb09 | ||
|
|
6c2862ae08 | ||
|
|
6c821ab73e | ||
|
|
96f13b15e7 | ||
|
|
d9731b6850 | ||
|
|
e7e49c9110 | ||
|
|
27590497d5 | ||
|
|
991f1a242c | ||
|
|
6d8cca43a8 | ||
|
|
d62bb81c3b | ||
|
|
156d0686c4 | ||
|
|
d0b1d29379 | ||
|
|
974bcdfd0b | ||
|
|
ed094b6a97 | ||
|
|
49624219ed | ||
|
|
9906a1d49a | ||
|
|
014884333d | ||
|
|
865ddaabdf | ||
|
|
3219834e35 | ||
|
|
2e295073ae | ||
|
|
5ef45cfec2 | ||
|
|
d737141efa | ||
|
|
dec0f411db | ||
|
|
93edf56824 | ||
|
|
77db0cda60 | ||
|
|
d2ff260e39 | ||
|
|
3c59371cbf | ||
|
|
8d4095e20e | ||
|
|
869677c107 | ||
|
|
e3aad64ee6 | ||
|
|
0422ac7ffd | ||
|
|
a8f7ff5142 | ||
|
|
016761471a | ||
|
|
6e61f0617a | ||
|
|
a456be6d7b | ||
|
|
a89d66f934 | ||
|
|
ff170ecee8 | ||
|
|
96e27a8997 | ||
|
|
4da310848c | ||
|
|
d79a9b0764 | ||
|
|
80336b71d6 | ||
|
|
a11fbda85e | ||
|
|
2b73238a45 | ||
|
|
a8988a9564 | ||
|
|
6d5dc76536 | ||
|
|
104e21f501 | ||
|
|
373d7e7708 | ||
|
|
b9533a2811 | ||
|
|
7b8951a761 | ||
|
|
cbe234d5be | ||
|
|
e392d1e7b3 | ||
|
|
16fc633b90 | ||
|
|
fb418448b8 | ||
|
|
8e3c6756ad | ||
|
|
61b8b06ec8 | ||
|
|
3cdc3d5df0 | ||
|
|
179e7dfaf1 | ||
|
|
de21bb5740 | ||
|
|
f9e99b337e | ||
|
|
49d65992fd | ||
|
|
ee62a86ad8 | ||
|
|
0c7d5d4dcd | ||
|
|
18cb38e535 | ||
|
|
37bf855027 | ||
|
|
5894b48c3d | ||
|
|
7fd9704d66 | ||
|
|
c2d6bd8623 | ||
|
|
139d46feff | ||
|
|
45b28cb4ae | ||
|
|
bbc525260c | ||
|
|
bfa4c51ca0 | ||
|
|
65fc2d2d50 | ||
|
|
a0707d5fa2 | ||
|
|
85d867e9af | ||
|
|
7646cabc53 | ||
|
|
8e94924aba | ||
|
|
fb9aa6f76c | ||
|
|
591d32d98a | ||
|
|
8491c38797 | ||
|
|
d66ced3acc | ||
|
|
de91bc86a5 | ||
|
|
26d137c2c3 | ||
|
|
26540e8be1 | ||
|
|
bb2012b768 | ||
|
|
eead092e91 | ||
|
|
64b7ca3faf | ||
|
|
9f8ca567af | ||
|
|
617ea40d00 | ||
|
|
943ab53efa | ||
|
|
2422b1df97 | ||
|
|
9ec47a803f | ||
|
|
021d319db9 | ||
|
|
e82c8d12c2 | ||
|
|
081db2b6b4 | ||
|
|
ac30a73947 | ||
|
|
2f80c468ff | ||
|
|
78b05bf008 | ||
|
|
2fd9cbf8f2 | ||
|
|
dce575fa2d | ||
|
|
69186bc6c8 | ||
|
|
d61b47a134 | ||
|
|
22a3564939 | ||
|
|
61e607fb37 | ||
|
|
544a7b08cd | ||
|
|
99691a6103 | ||
|
|
4a22138fff | ||
|
|
92fb3507c9 | ||
|
|
73d06b2919 | ||
|
|
459f999175 | ||
|
|
e9fe3dcb3b | ||
|
|
c998a4da68 | ||
|
|
ddf45d9b1d | ||
|
|
8f62a97a26 | ||
|
|
ee66151692 | ||
|
|
e6dc590ef1 | ||
|
|
36e2e5942a | ||
|
|
a6096d0b46 | ||
|
|
f107e21d26 | ||
|
|
516591c012 | ||
|
|
9efb67a3bd | ||
|
|
c5ef7a5944 | ||
|
|
20366ba973 | ||
|
|
df03a56888 | ||
|
|
d202c90f5f | ||
|
|
7addb78158 | ||
|
|
8afa6cf51b | ||
|
|
1289688b64 | ||
|
|
e349d37b8c | ||
|
|
6fec7b729d | ||
|
|
cd05434d7f | ||
|
|
9e7b74ea32 | ||
|
|
4646439108 | ||
|
|
f89e41ac30 | ||
|
|
9b0029c5bb | ||
|
|
3f247952fa | ||
|
|
dc360c8a5c | ||
|
|
5f06aad131 | ||
|
|
26ca1cf2d7 | ||
|
|
75c9a09ad1 | ||
|
|
139a5f7caf | ||
|
|
4caa72d080 | ||
|
|
2f2a1c5c58 | ||
|
|
37e0f7fd6e | ||
|
|
b012176c9c | ||
|
|
a5e1a9fd99 | ||
|
|
0b0d77bcdf | ||
|
|
3791a76216 | ||
|
|
b921f06e2b | ||
|
|
07b8391605 | ||
|
|
2ec03b8c55 | ||
|
|
8beb9b4638 | ||
|
|
b40f55a328 | ||
|
|
4e0d553380 | ||
|
|
42c40d75b1 | ||
|
|
6e30c62078 | ||
|
|
f29161b7f3 | ||
|
|
7d084db6d7 | ||
|
|
0ab08e93a6 | ||
|
|
d3586bf820 | ||
|
|
e3dbb00d4e | ||
|
|
e11b2008f3 | ||
|
|
a02b5a6c0e | ||
|
|
3b3b05dc33 | ||
|
|
7d6392f793 | ||
|
|
ec3c33afac | ||
|
|
eb847de7ec | ||
|
|
c3e91baa53 | ||
|
|
d2003c83fb | ||
|
|
7c0a939d96 | ||
|
|
f45b86a396 | ||
|
|
d7bf698d1e | ||
|
|
d655049934 | ||
|
|
6357b46001 | ||
|
|
186f4423e0 | ||
|
|
baf323a26c | ||
|
|
cc7eef9fc0 | ||
|
|
c9a2a6c17f | ||
|
|
2a857a676f | ||
|
|
cf7096e80d | ||
|
|
cfd27b1dce | ||
|
|
c36b628879 | ||
|
|
a34cc6b7e7 | ||
|
|
d70006717e | ||
|
|
bf57a3ac6d | ||
|
|
ffc77fe229 | ||
|
|
82082fcee3 | ||
|
|
8d1f8c24f3 | ||
|
|
0369bc77dd | ||
|
|
1ef111d954 | ||
|
|
69db41aa1d | ||
|
|
a7118ddda6 | ||
|
|
86494cdd90 | ||
|
|
101aa68424 | ||
|
|
47b225d76d | ||
|
|
06758d352a | ||
|
|
6dc6f9514e | ||
|
|
08519c2e44 | ||
|
|
cc1e4b8c4a | ||
|
|
0d6ff3ac50 | ||
|
|
b15ffa29a5 | ||
|
|
5f2ce8e18a | ||
|
|
8f90374f49 | ||
|
|
4c38beb456 | ||
|
|
02f009e6b5 | ||
|
|
fed53185ac | ||
|
|
5cdebc3ed5 | ||
|
|
947fc2f616 | ||
|
|
939242fc22 | ||
|
|
f787f6a089 | ||
|
|
f687bcccf7 | ||
|
|
ba06aa3c0c | ||
|
|
36f516b337 | ||
|
|
3d4805f4b1 | ||
|
|
bf178fcc0e | ||
|
|
7c41d6f30f | ||
|
|
7906b38ded | ||
|
|
d74b0e3fc6 | ||
|
|
07b6ce5ed0 |
@@ -2,7 +2,7 @@ BACKEND_HOST ?= "127.0.0.1"
|
||||
BACKEND_PORT = 3000
|
||||
BACKEND_HOST_PORT = "$(BACKEND_HOST):$(BACKEND_PORT)"
|
||||
FRONTEND_PORT = 3001
|
||||
OPENHANDS_PATH ?= "../../OpenHands"
|
||||
OPENHANDS_PATH ?= ".."
|
||||
OPENHANDS := $(OPENHANDS_PATH)
|
||||
OPENHANDS_FRONTEND_PATH = $(OPENHANDS)/frontend/build
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from integrations.utils import (
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
)
|
||||
from integrations.v1_utils import get_saas_user_auth
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import SecretStr
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
@@ -164,8 +165,13 @@ class GithubManager(Manager):
|
||||
)
|
||||
|
||||
if await self.is_job_requested(message):
|
||||
payload = message.message.get('payload', {})
|
||||
user_id = payload['sender']['id']
|
||||
keycloak_user_id = await self.token_manager.get_user_id_from_idp_user_id(
|
||||
user_id, ProviderType.GITHUB
|
||||
)
|
||||
github_view = await GithubFactory.create_github_view_from_payload(
|
||||
message, self.token_manager
|
||||
message, keycloak_user_id
|
||||
)
|
||||
logger.info(
|
||||
f'[GitHub] Creating job for {github_view.user_info.username} in {github_view.full_repo_name}#{github_view.issue_number}'
|
||||
@@ -282,8 +288,15 @@ class GithubManager(Manager):
|
||||
f'[Github]: Error summarizing issue solvability: {str(e)}'
|
||||
)
|
||||
|
||||
saas_user_auth = await get_saas_user_auth(
|
||||
github_view.user_info.keycloak_user_id, self.token_manager
|
||||
)
|
||||
|
||||
await github_view.create_new_conversation(
|
||||
self.jinja_env, secret_store.provider_tokens, convo_metadata
|
||||
self.jinja_env,
|
||||
secret_store.provider_tokens,
|
||||
convo_metadata,
|
||||
saas_user_auth,
|
||||
)
|
||||
|
||||
conversation_id = github_view.conversation_id
|
||||
@@ -292,18 +305,19 @@ class GithubManager(Manager):
|
||||
f'[GitHub] Created conversation {conversation_id} for user {user_info.username}'
|
||||
)
|
||||
|
||||
# Create a GithubCallbackProcessor
|
||||
processor = GithubCallbackProcessor(
|
||||
github_view=github_view,
|
||||
send_summary_instruction=True,
|
||||
)
|
||||
if not github_view.v1:
|
||||
# Create a GithubCallbackProcessor
|
||||
processor = GithubCallbackProcessor(
|
||||
github_view=github_view,
|
||||
send_summary_instruction=True,
|
||||
)
|
||||
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[Github] Registered callback processor for conversation {conversation_id}'
|
||||
)
|
||||
logger.info(
|
||||
f'[Github] Registered callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
# Send message with conversation link
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from uuid import uuid4
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from github import Github, GithubIntegration
|
||||
from github.Issue import Issue
|
||||
@@ -8,6 +9,7 @@ from integrations.github.github_types import (
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
from integrations.models import Message
|
||||
from integrations.resolver_context import ResolverUserContext
|
||||
from integrations.types import ResolverViewInterface, UserData
|
||||
from integrations.utils import (
|
||||
ENABLE_PROACTIVE_CONVERSATION_STARTERS,
|
||||
@@ -17,23 +19,32 @@ from integrations.utils import (
|
||||
has_exact_mention,
|
||||
)
|
||||
from jinja2 import Environment
|
||||
from pydantic.dataclasses import dataclass
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.org_store import OrgStore
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.agent_server.models import SendMessageRequest
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
AppConversationStartTaskStatus,
|
||||
)
|
||||
from openhands.app_server.config import get_app_conversation_service
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
from openhands.integrations.service_types import Comment
|
||||
from openhands.sdk import TextContent
|
||||
from openhands.server.services.conversation_service import (
|
||||
initialize_conversation,
|
||||
start_conversation,
|
||||
)
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
@@ -61,19 +72,36 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
|
||||
if not settings or settings.enable_proactive_conversation_starters is None:
|
||||
# Check global setting first - if disabled globally, return False
|
||||
if not ENABLE_PROACTIVE_CONVERSATION_STARTERS:
|
||||
return False
|
||||
|
||||
return settings.enable_proactive_conversation_starters
|
||||
def _get_setting():
|
||||
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
return False
|
||||
return bool(org.enable_proactive_conversation_starters)
|
||||
|
||||
return await call_sync_from_async(_get_setting)
|
||||
|
||||
|
||||
async def get_user_v1_enabled_setting(user_id: str) -> bool:
|
||||
"""Get the user's V1 conversation API setting.
|
||||
|
||||
Args:
|
||||
user_id: The keycloak user ID
|
||||
|
||||
Returns:
|
||||
True if V1 conversations are enabled for this user, False otherwise
|
||||
"""
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
|
||||
if not org or org.v1_enabled is None:
|
||||
return False
|
||||
|
||||
return org.v1_enabled
|
||||
|
||||
|
||||
# =================================================
|
||||
@@ -96,6 +124,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
title: str
|
||||
description: str
|
||||
previous_comments: list[Comment]
|
||||
v1: bool
|
||||
|
||||
async def _load_resolver_context(self):
|
||||
github_service = GithubServiceImpl(
|
||||
@@ -130,6 +159,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
issue_body=self.description,
|
||||
previous_comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
@@ -142,6 +172,19 @@ class GithubIssue(ResolverViewInterface):
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
|
||||
v1_enabled = await get_user_v1_enabled_setting(self.user_info.keycloak_user_id)
|
||||
logger.info(
|
||||
f'[GitHub V1]: User flag found for {self.user_info.keycloak_user_id} is {v1_enabled}'
|
||||
)
|
||||
if v1_enabled:
|
||||
# Create dummy conversationm metadata
|
||||
# Don't save to conversation store
|
||||
# V1 conversations are stored in a separate table
|
||||
return ConversationMetadata(
|
||||
conversation_id=uuid4().hex, selected_repository=self.full_repo_name
|
||||
)
|
||||
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
@@ -150,6 +193,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
self.conversation_id = conversation_metadata.conversation_id
|
||||
return conversation_metadata
|
||||
|
||||
@@ -158,7 +202,36 @@ class GithubIssue(ResolverViewInterface):
|
||||
jinja_env: Environment,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
saas_user_auth: UserAuth,
|
||||
):
|
||||
v1_enabled = await get_user_v1_enabled_setting(self.user_info.keycloak_user_id)
|
||||
logger.info(
|
||||
f'[GitHub V1]: User flag found for {self.user_info.keycloak_user_id} is {v1_enabled}'
|
||||
)
|
||||
if v1_enabled:
|
||||
try:
|
||||
# Use V1 app conversation service
|
||||
await self._create_v1_conversation(
|
||||
jinja_env, saas_user_auth, conversation_metadata
|
||||
)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f'Error checking V1 settings, falling back to V0: {e}')
|
||||
|
||||
# Use existing V0 conversation service
|
||||
await self._create_v0_conversation(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
|
||||
async def _create_v0_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
):
|
||||
"""Create conversation using the legacy V0 system."""
|
||||
logger.info('[GitHub V1]: Creating V0 conversation')
|
||||
custom_secrets = await self._get_user_secrets()
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
@@ -177,6 +250,78 @@ class GithubIssue(ResolverViewInterface):
|
||||
conversation_instructions=conversation_instructions,
|
||||
)
|
||||
|
||||
async def _create_v1_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
saas_user_auth: UserAuth,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
):
|
||||
"""Create conversation using the new V1 app conversation system."""
|
||||
logger.info('[GitHub V1]: Creating V1 conversation')
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja_env
|
||||
)
|
||||
|
||||
# Create the initial message request
|
||||
initial_message = SendMessageRequest(
|
||||
role='user', content=[TextContent(text=user_instructions)]
|
||||
)
|
||||
|
||||
# Create the GitHub V1 callback processor
|
||||
github_callback_processor = self._create_github_v1_callback_processor()
|
||||
|
||||
# Get the app conversation service and start the conversation
|
||||
injector_state = InjectorState()
|
||||
|
||||
# Create the V1 conversation start request with the callback processor
|
||||
start_request = AppConversationStartRequest(
|
||||
conversation_id=UUID(conversation_metadata.conversation_id),
|
||||
system_message_suffix=conversation_instructions,
|
||||
initial_message=initial_message,
|
||||
selected_repository=self.full_repo_name,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title=f'GitHub Issue #{self.issue_number}: {self.title}',
|
||||
trigger=ConversationTrigger.RESOLVER,
|
||||
processors=[
|
||||
github_callback_processor
|
||||
], # Pass the callback processor directly
|
||||
)
|
||||
|
||||
# Set up the GitHub user context for the V1 system
|
||||
github_user_context = ResolverUserContext(saas_user_auth=saas_user_auth)
|
||||
setattr(injector_state, USER_CONTEXT_ATTR, github_user_context)
|
||||
|
||||
async with get_app_conversation_service(
|
||||
injector_state
|
||||
) as app_conversation_service:
|
||||
async for task in app_conversation_service.start_app_conversation(
|
||||
start_request
|
||||
):
|
||||
if task.status == AppConversationStartTaskStatus.ERROR:
|
||||
logger.error(f'Failed to start V1 conversation: {task.detail}')
|
||||
raise RuntimeError(
|
||||
f'Failed to start V1 conversation: {task.detail}'
|
||||
)
|
||||
|
||||
self.v1 = True
|
||||
|
||||
def _create_github_v1_callback_processor(self):
|
||||
"""Create a V1 callback processor for GitHub integration."""
|
||||
from openhands.app_server.event_callback.github_v1_callback_processor import (
|
||||
GithubV1CallbackProcessor,
|
||||
)
|
||||
|
||||
# Create and return the GitHub V1 callback processor
|
||||
return GithubV1CallbackProcessor(
|
||||
github_view_data={
|
||||
'issue_number': self.issue_number,
|
||||
'full_repo_name': self.full_repo_name,
|
||||
'installation_id': self.installation_id,
|
||||
},
|
||||
send_summary_instruction=self.send_summary_instruction,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubIssueComment(GithubIssue):
|
||||
@@ -195,7 +340,6 @@ class GithubIssueComment(GithubIssue):
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
issue_title=self.title,
|
||||
@@ -232,8 +376,7 @@ class GithubPRComment(GithubIssueComment):
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation(
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
selected_repository=self.full_repo_name,
|
||||
@@ -279,7 +422,6 @@ class GithubInlinePRComment(GithubPRComment):
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'pr_update_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
pr_number=self.issue_number,
|
||||
pr_title=self.title,
|
||||
@@ -292,6 +434,24 @@ class GithubInlinePRComment(GithubPRComment):
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
def _create_github_v1_callback_processor(self):
|
||||
"""Create a V1 callback processor for GitHub integration."""
|
||||
from openhands.app_server.event_callback.github_v1_callback_processor import (
|
||||
GithubV1CallbackProcessor,
|
||||
)
|
||||
|
||||
# Create and return the GitHub V1 callback processor
|
||||
return GithubV1CallbackProcessor(
|
||||
github_view_data={
|
||||
'issue_number': self.issue_number,
|
||||
'full_repo_name': self.full_repo_name,
|
||||
'installation_id': self.installation_id,
|
||||
'comment_id': self.comment_id,
|
||||
},
|
||||
inline_pr_comment=True,
|
||||
send_summary_instruction=self.send_summary_instruction,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubFailingAction:
|
||||
@@ -605,7 +765,7 @@ class GithubFactory:
|
||||
|
||||
@staticmethod
|
||||
async def create_github_view_from_payload(
|
||||
message: Message, token_manager: TokenManager
|
||||
message: Message, keycloak_user_id: str
|
||||
) -> ResolverViewInterface:
|
||||
"""Create the appropriate class (GithubIssue or GithubPRComment) based on the payload.
|
||||
Also return metadata about the event (e.g., action type).
|
||||
@@ -615,17 +775,10 @@ class GithubFactory:
|
||||
user_id = payload['sender']['id']
|
||||
username = payload['sender']['login']
|
||||
|
||||
keyloak_user_id = await token_manager.get_user_id_from_idp_user_id(
|
||||
user_id, ProviderType.GITHUB
|
||||
)
|
||||
|
||||
if keyloak_user_id is None:
|
||||
logger.warning(f'Got invalid keyloak user id for GitHub User {user_id} ')
|
||||
|
||||
selected_repo = GithubFactory.get_full_repo_name(repo_obj)
|
||||
is_public_repo = not repo_obj.get('private', True)
|
||||
user_info = UserData(
|
||||
user_id=user_id, username=username, keycloak_user_id=keyloak_user_id
|
||||
user_id=user_id, username=username, keycloak_user_id=keycloak_user_id
|
||||
)
|
||||
|
||||
installation_id = message.message['installation']
|
||||
@@ -649,6 +802,7 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
elif GithubFactory.is_issue_comment(message):
|
||||
@@ -674,6 +828,7 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
elif GithubFactory.is_pr_comment(message):
|
||||
@@ -715,6 +870,7 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
elif GithubFactory.is_inline_pr_comment(message):
|
||||
@@ -748,6 +904,7 @@ class GithubFactory:
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
63
enterprise/integrations/resolver_context.py
Normal file
63
enterprise/integrations/resolver_context.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.app_server.user.user_models import UserInfo
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.sdk.conversation.secret_source import SecretSource, StaticSecret
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
class ResolverUserContext(UserContext):
|
||||
"""User context for resolver operations that inherits from UserContext."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
saas_user_auth: UserAuth,
|
||||
):
|
||||
self.saas_user_auth = saas_user_auth
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return await self.saas_user_auth.get_user_id()
|
||||
|
||||
async def get_user_info(self) -> UserInfo:
|
||||
user_settings = await self.saas_user_auth.get_user_settings()
|
||||
user_id = await self.saas_user_auth.get_user_id()
|
||||
if user_settings:
|
||||
return UserInfo(
|
||||
id=user_id,
|
||||
**user_settings.model_dump(context={'expose_secrets': True}),
|
||||
)
|
||||
|
||||
return UserInfo(id=user_id)
|
||||
|
||||
async def get_authenticated_git_url(self, repository: str) -> str:
|
||||
# This would need to be implemented based on the git provider tokens
|
||||
# For now, return a basic HTTPS URL
|
||||
return f'https://github.com/{repository}.git'
|
||||
|
||||
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
|
||||
# Return the appropriate token from git_provider_tokens
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if provider_tokens:
|
||||
return provider_tokens.get(provider_type)
|
||||
return None
|
||||
|
||||
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
|
||||
return await self.saas_user_auth.get_provider_tokens()
|
||||
|
||||
async def get_secrets(self) -> dict[str, SecretSource]:
|
||||
"""Get secrets for the user, including custom secrets."""
|
||||
secrets = await self.saas_user_auth.get_secrets()
|
||||
if secrets:
|
||||
# Convert custom secrets to StaticSecret objects for SDK compatibility
|
||||
# secrets.custom_secrets is of type Mapping[str, CustomSecret]
|
||||
converted_secrets = {}
|
||||
for key, custom_secret in secrets.custom_secrets.items():
|
||||
# Extract the secret value from CustomSecret and convert to StaticSecret
|
||||
secret_value = custom_secret.secret.get_secret_value()
|
||||
converted_secrets[key] = StaticSecret(value=secret_value)
|
||||
return converted_secrets
|
||||
return {}
|
||||
|
||||
async def get_mcp_api_key(self) -> str | None:
|
||||
return await self.saas_user_auth.get_mcp_api_key()
|
||||
@@ -167,6 +167,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
'channel_id': self.channel_id,
|
||||
'conversation_id': self.conversation_id,
|
||||
'keycloak_user_id': user_info.keycloak_user_id,
|
||||
'org_id': user_info.org_id,
|
||||
'parent_id': self.thread_ts or self.message_ts,
|
||||
},
|
||||
)
|
||||
@@ -174,6 +175,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
conversation_id=self.conversation_id,
|
||||
channel_id=self.channel_id,
|
||||
keycloak_user_id=user_info.keycloak_user_id,
|
||||
org_id=user_info.org_id,
|
||||
parent_id=self.thread_ts
|
||||
or self.message_ts, # conversations can start in a thread reply as well; we should always references the parent's (root level msg's) message ID
|
||||
)
|
||||
@@ -304,10 +306,10 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
user_msg, _ = self._get_instructions(jinja)
|
||||
user_msg_action = MessageAction(content=user_msg)
|
||||
instructions, _ = self._get_instructions(jinja)
|
||||
user_msg = MessageAction(content=instructions)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_msg_action)
|
||||
self.conversation_id, event_to_dict(user_msg)
|
||||
)
|
||||
|
||||
return self.conversation_id
|
||||
|
||||
@@ -1,19 +1,24 @@
|
||||
from uuid import UUID
|
||||
|
||||
import stripe
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.database import session_maker
|
||||
from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
|
||||
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
with session_maker() as session:
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.filter(StripeCustomer.org_id == org_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer:
|
||||
@@ -21,46 +26,76 @@ async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
|
||||
# If that fails, fallback to stripe
|
||||
search_result = await stripe.Customer.search_async(
|
||||
query=f"metadata['user_id']:'{user_id}'",
|
||||
query=f"metadata['org_id']:'{str(org_id)}'",
|
||||
)
|
||||
data = search_result.data
|
||||
if not data:
|
||||
logger.info('no_customer_for_user_id', extra={'user_id': user_id})
|
||||
logger.info(
|
||||
'no_customer_for_org_id',
|
||||
extra={'org_id': str(org_id)},
|
||||
)
|
||||
return None
|
||||
return data[0].id # type: ignore [attr-defined]
|
||||
|
||||
|
||||
async def find_or_create_customer(user_id: str) -> str:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id:
|
||||
return customer_id
|
||||
logger.info('creating_customer', extra={'user_id': user_id})
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
customer_id = await find_customer_id_by_org_id(org.id)
|
||||
return customer_id
|
||||
|
||||
# Get the user info from keycloak
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id) or {}
|
||||
|
||||
async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
# Get the current org for the user
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
|
||||
customer_id = await find_customer_id_by_org_id(org.id)
|
||||
if customer_id:
|
||||
return {'customer_id': customer_id, 'org_id': str(org.id)}
|
||||
logger.info(
|
||||
'creating_customer',
|
||||
extra={'user_id': user_id, 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
# Create the customer in stripe
|
||||
customer = await stripe.Customer.create_async(
|
||||
email=str(user_info.get('email', '')),
|
||||
metadata={'user_id': user_id},
|
||||
email=org.contact_email,
|
||||
metadata={'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
# Save the stripe customer in the local db
|
||||
with session_maker() as session:
|
||||
session.add(
|
||||
StripeCustomer(keycloak_user_id=user_id, stripe_customer_id=customer.id)
|
||||
StripeCustomer(
|
||||
keycloak_user_id=user_id,
|
||||
org_id=org.id,
|
||||
stripe_customer_id=customer.id,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'created_customer',
|
||||
extra={'user_id': user_id, 'stripe_customer_id': customer.id},
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
return customer.id
|
||||
return {'customer_id': customer.id, 'org_id': str(org.id)}
|
||||
|
||||
|
||||
async def has_payment_method(user_id: str) -> bool:
|
||||
async def has_payment_method_by_user_id(user_id: str) -> bool:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id is None:
|
||||
return False
|
||||
@@ -71,3 +106,28 @@ async def has_payment_method(user_id: str) -> bool:
|
||||
f'has_payment_method:{user_id}:{customer_id}:{bool(payment_methods.data)}'
|
||||
)
|
||||
return bool(payment_methods.data)
|
||||
|
||||
|
||||
async def migrate_customer(session: Session, user_id: str, org: Org):
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ class PRStatus(Enum):
|
||||
class UserData(BaseModel):
|
||||
user_id: int
|
||||
username: str
|
||||
keycloak_user_id: str | None
|
||||
keycloak_user_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
20
enterprise/integrations/v1_utils.py
Normal file
20
enterprise/integrations/v1_utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from pydantic import SecretStr
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
async def get_saas_user_auth(
|
||||
keycloak_user_id: str, token_manager: TokenManager
|
||||
) -> UserAuth:
|
||||
offline_token = await token_manager.load_offline_token(keycloak_user_id)
|
||||
if offline_token is None:
|
||||
logger.info('no_offline_token_found')
|
||||
|
||||
user_auth = SaasUserAuth(
|
||||
user_id=keycloak_user_id,
|
||||
refresh_token=SecretStr(offline_token),
|
||||
)
|
||||
return user_auth
|
||||
@@ -20,6 +20,8 @@ down_revision = '059'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# TODO: decide whether to modify this for orgs or users
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""
|
||||
@@ -28,8 +30,10 @@ def upgrade():
|
||||
|
||||
This replaces the functionality of the removed admin maintenance endpoint.
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
|
||||
# Hardcoded value to prevent migration failures when constant is removed from codebase
|
||||
# This migration has already run in production, so we use the value that was current at the time
|
||||
CURRENT_USER_SETTINGS_VERSION = 4
|
||||
|
||||
# Create a connection and bind it to a session
|
||||
connection = op.get_bind()
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Add v1_enabled column to user_settings
|
||||
|
||||
Revision ID: 083
|
||||
Revises: 082
|
||||
Create Date: 2025-11-18 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '083'
|
||||
down_revision: Union[str, None] = '082'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add v1_enabled column to user_settings table."""
|
||||
op.add_column(
|
||||
'user_settings',
|
||||
sa.Column(
|
||||
'v1_enabled',
|
||||
sa.Boolean(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove v1_enabled column from user_settings table."""
|
||||
op.drop_column('user_settings', 'v1_enabled')
|
||||
272
enterprise/migrations/versions/084_create_org_tables.py
Normal file
272
enterprise/migrations/versions/084_create_org_tables.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""create org tables from pgerd schema
|
||||
|
||||
Revision ID: 084
|
||||
Revises: 083
|
||||
Create Date: 2025-01-07 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '084'
|
||||
down_revision: Union[str, None] = '083'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Remove current settings table
|
||||
op.execute('DROP TABLE IF EXISTS settings')
|
||||
|
||||
# Add already_migrated column to user_settings table
|
||||
op.add_column(
|
||||
'user_settings',
|
||||
sa.Column(
|
||||
'already_migrated',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
# Create role table
|
||||
op.create_table(
|
||||
'role',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('name', sa.String, nullable=False),
|
||||
sa.Column('rank', sa.Integer, nullable=False),
|
||||
sa.UniqueConstraint('name', name='role_name_unique'),
|
||||
)
|
||||
|
||||
# 1. Create default roles
|
||||
op.execute(
|
||||
sa.text("""
|
||||
INSERT INTO role (name, rank) VALUES ('owner', 10), ('admin', 20), ('user', 1000)
|
||||
ON CONFLICT (name) DO NOTHING;
|
||||
""")
|
||||
)
|
||||
|
||||
# Create org table with settings fields
|
||||
op.create_table(
|
||||
'org',
|
||||
sa.Column(
|
||||
'id',
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('name', sa.String, nullable=False),
|
||||
sa.Column('contact_name', sa.String, nullable=True),
|
||||
sa.Column('contact_email', sa.String, nullable=True),
|
||||
sa.Column('conversation_expiration', sa.Integer, nullable=True),
|
||||
# Settings fields moved to org table
|
||||
sa.Column('agent', sa.String, nullable=True),
|
||||
sa.Column('default_max_iterations', sa.Integer, nullable=True),
|
||||
sa.Column('security_analyzer', sa.String, nullable=True),
|
||||
sa.Column(
|
||||
'confirmation_mode',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('default_llm_model', sa.String, nullable=True),
|
||||
sa.Column('_default_llm_api_key_for_byor', sa.String, nullable=True),
|
||||
sa.Column('default_llm_base_url', sa.String, nullable=True),
|
||||
sa.Column('remote_runtime_resource_factor', sa.Integer, nullable=True),
|
||||
sa.Column(
|
||||
'enable_default_condenser',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column('billing_margin', sa.Float, nullable=True),
|
||||
sa.Column(
|
||||
'enable_proactive_conversation_starters',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column('sandbox_base_container_image', sa.String, nullable=True),
|
||||
sa.Column('sandbox_runtime_container_image', sa.String, nullable=True),
|
||||
sa.Column(
|
||||
'org_version', sa.Integer, nullable=False, server_default=sa.text('0')
|
||||
),
|
||||
sa.Column('mcp_config', sa.JSON, nullable=True),
|
||||
sa.Column('_search_api_key', sa.String, nullable=True),
|
||||
sa.Column('_sandbox_api_key', sa.String, nullable=True),
|
||||
sa.Column('max_budget_per_task', sa.Float, nullable=True),
|
||||
sa.Column(
|
||||
'enable_solvability_analysis',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('v1_enabled', sa.Boolean, nullable=True),
|
||||
sa.UniqueConstraint('name', name='org_name_unique'),
|
||||
)
|
||||
|
||||
# Create user table with user-specific settings fields
|
||||
op.create_table(
|
||||
'user',
|
||||
sa.Column(
|
||||
'id',
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('current_org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=True),
|
||||
sa.Column('accepted_tos', sa.DateTime, nullable=True),
|
||||
sa.Column(
|
||||
'enable_sound_notifications',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('language', sa.String, nullable=True),
|
||||
sa.Column('user_consents_to_analytics', sa.Boolean, nullable=True),
|
||||
sa.Column('email', sa.String, nullable=True),
|
||||
sa.Column('email_verified', sa.Boolean, nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
['current_org_id'], ['org.id'], name='current_org_fkey'
|
||||
),
|
||||
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='user_role_fkey'),
|
||||
)
|
||||
|
||||
# Create org_member table (junction table for many-to-many relationship)
|
||||
op.create_table(
|
||||
'org_member',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=False),
|
||||
sa.Column('_llm_api_key', sa.String, nullable=False),
|
||||
sa.Column('max_iterations', sa.Integer, nullable=True),
|
||||
sa.Column('llm_model', sa.String, nullable=True),
|
||||
sa.Column('_llm_api_key_for_byor', sa.String, nullable=True),
|
||||
sa.Column('llm_base_url', sa.String, nullable=True),
|
||||
sa.Column('status', sa.String, nullable=True),
|
||||
sa.ForeignKeyConstraint(['org_id'], ['org.id'], name='om_org_fkey'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='om_user_fkey'),
|
||||
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='om_role_fkey'),
|
||||
sa.PrimaryKeyConstraint('org_id', 'user_id'),
|
||||
)
|
||||
|
||||
# Add org_id column to existing tables
|
||||
# billing_sessions
|
||||
op.add_column(
|
||||
'billing_sessions',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'billing_sessions_org_fkey', 'billing_sessions', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# Create conversation_metadata_saas table
|
||||
op.create_table(
|
||||
'conversation_metadata_saas',
|
||||
sa.Column('conversation_id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
['user_id'], ['user.id'], name='conversation_metadata_saas_user_fkey'
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['org_id'], ['org.id'], name='conversation_metadata_saas_org_fkey'
|
||||
),
|
||||
sa.PrimaryKeyConstraint('conversation_id'),
|
||||
)
|
||||
|
||||
# custom_secrets
|
||||
op.add_column(
|
||||
'custom_secrets',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'custom_secrets_org_fkey', 'custom_secrets', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# api_keys
|
||||
op.add_column(
|
||||
'api_keys', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
|
||||
)
|
||||
op.create_foreign_key('api_keys_org_fkey', 'api_keys', 'org', ['org_id'], ['id'])
|
||||
|
||||
# slack_conversation
|
||||
op.add_column(
|
||||
'slack_conversation',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'slack_conversation_org_fkey', 'slack_conversation', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# slack_users
|
||||
op.add_column(
|
||||
'slack_users', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'slack_users_org_fkey', 'slack_users', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# stripe_customers
|
||||
op.alter_column(
|
||||
'stripe_customers',
|
||||
'keycloak_user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
op.add_column(
|
||||
'stripe_customers',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'stripe_customers_org_fkey', 'stripe_customers', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop already_migrated column from user_settings table
|
||||
op.drop_column('user_settings', 'already_migrated')
|
||||
|
||||
# Drop foreign keys and columns added to existing tables
|
||||
op.drop_constraint(
|
||||
'stripe_customers_org_fkey', 'stripe_customers', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('stripe_customers', 'org_id')
|
||||
op.alter_column(
|
||||
'stripe_customers',
|
||||
'keycloak_user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
op.drop_constraint('slack_users_org_fkey', 'slack_users', type_='foreignkey')
|
||||
op.drop_column('slack_users', 'org_id')
|
||||
|
||||
op.drop_constraint(
|
||||
'slack_conversation_org_fkey', 'slack_conversation', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('slack_conversation', 'org_id')
|
||||
|
||||
op.drop_constraint('api_keys_org_fkey', 'api_keys', type_='foreignkey')
|
||||
op.drop_column('api_keys', 'org_id')
|
||||
|
||||
op.drop_constraint('custom_secrets_org_fkey', 'custom_secrets', type_='foreignkey')
|
||||
op.drop_column('custom_secrets', 'org_id')
|
||||
|
||||
# Drop conversation_metadata_saas table
|
||||
op.drop_table('conversation_metadata_saas')
|
||||
|
||||
op.drop_constraint(
|
||||
'billing_sessions_org_fkey', 'billing_sessions', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('billing_sessions', 'org_id')
|
||||
|
||||
# Drop tables in reverse order due to foreign key constraints
|
||||
op.drop_table('org_member')
|
||||
op.drop_table('user')
|
||||
op.drop_table('org')
|
||||
op.drop_table('role')
|
||||
10031
enterprise/poetry.lock
generated
10031
enterprise/poetry.lock
generated
File diff suppressed because one or more lines are too long
@@ -4,6 +4,10 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Ensure SAAS configuration is used
|
||||
if not os.getenv('OPENHANDS_CONFIG_CLS'):
|
||||
os.environ['OPENHANDS_CONFIG_CLS'] = 'server.config.SaaSServerConfig'
|
||||
|
||||
import socketio # noqa: E402
|
||||
from fastapi import Request, status # noqa: E402
|
||||
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
|
||||
|
||||
@@ -102,7 +102,6 @@ class SaasUserAuth(UserAuth):
|
||||
return settings
|
||||
settings_store = await self.get_user_settings_store()
|
||||
settings = await settings_store.load()
|
||||
# If load() returned None, should settings be created?
|
||||
if settings:
|
||||
settings.email = self.email
|
||||
settings.email_verified = self.email_verified
|
||||
@@ -203,6 +202,15 @@ class SaasUserAuth(UserAuth):
|
||||
self.settings_store = settings_store
|
||||
return settings_store
|
||||
|
||||
async def get_mcp_api_key(self) -> str:
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
mcp_api_key = api_key_store.retrieve_mcp_api_key(self.user_id)
|
||||
if not mcp_api_key:
|
||||
mcp_api_key = api_key_store.create_api_key(
|
||||
self.user_id, 'MCP_API_KEY', None
|
||||
)
|
||||
return mcp_api_key
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls, request: Request) -> UserAuth:
|
||||
logger.debug('saas_user_auth_get_instance')
|
||||
@@ -243,7 +251,12 @@ def get_api_key_from_header(request: Request):
|
||||
# This is a temp hack
|
||||
# Streamable HTTP MCP Client works via redirect requests, but drops the Authorization header for reason
|
||||
# We include `X-Session-API-Key` header by default due to nested runtimes, so it used as a drop in replacement here
|
||||
return request.headers.get('X-Session-API-Key')
|
||||
session_api_key = request.headers.get('X-Session-API-Key')
|
||||
if session_api_key:
|
||||
return session_api_key
|
||||
|
||||
# Fallback to X-Access-Token header as an additional option
|
||||
return request.headers.get('X-Access-Token')
|
||||
|
||||
|
||||
async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
|
||||
|
||||
@@ -9,7 +9,7 @@ from server.logger import logger
|
||||
from server.utils.conversation_callback_utils import invoke_conversation_callbacks
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
@@ -525,16 +525,18 @@ class ClusteredConversationManager(StandaloneConversationManager):
|
||||
)
|
||||
# Look up the user_id from the database
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
user_id = (
|
||||
conversation_metadata.user_id if conversation_metadata else None
|
||||
str(conversation_metadata_saas.user_id)
|
||||
if conversation_metadata_saas
|
||||
else None
|
||||
)
|
||||
# Handle the stopped conversation asynchronously
|
||||
asyncio.create_task(
|
||||
|
||||
@@ -19,8 +19,8 @@ IS_LOCAL_ENV = bool(HOST == 'localhost')
|
||||
DEFAULT_BILLING_MARGIN = float(os.environ.get('DEFAULT_BILLING_MARGIN', '1.0'))
|
||||
|
||||
# Map of user settings versions to their corresponding default LLM models
|
||||
# This ensures that CURRENT_USER_SETTINGS_VERSION and LITELLM_DEFAULT_MODEL stay in sync
|
||||
USER_SETTINGS_VERSION_TO_MODEL = {
|
||||
# This ensures that PERSONAL_WORKSPACE_VERSION_TO_MODEL and LITELLM_DEFAULT_MODEL stay in sync
|
||||
PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
|
||||
1: 'claude-3-5-sonnet-20241022',
|
||||
2: 'claude-3-7-sonnet-20250219',
|
||||
3: 'claude-sonnet-4-20250514',
|
||||
@@ -30,29 +30,17 @@ USER_SETTINGS_VERSION_TO_MODEL = {
|
||||
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
|
||||
|
||||
# Current user settings version - this should be the latest key in USER_SETTINGS_VERSION_TO_MODEL
|
||||
CURRENT_USER_SETTINGS_VERSION = max(USER_SETTINGS_VERSION_TO_MODEL.keys())
|
||||
ORG_SETTINGS_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
|
||||
PERSONAL_WORKSPACE_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
|
||||
|
||||
LITE_LLM_API_URL = os.environ.get(
|
||||
'LITE_LLM_API_URL', 'https://llm-proxy.app.all-hands.dev'
|
||||
)
|
||||
LITE_LLM_TEAM_ID = os.environ.get('LITE_LLM_TEAM_ID', None)
|
||||
LITE_LLM_API_KEY = os.environ.get('LITE_LLM_API_KEY', None)
|
||||
SUBSCRIPTION_PRICE_DATA = {
|
||||
'MONTHLY_SUBSCRIPTION': {
|
||||
'unit_amount': 2000,
|
||||
'currency': 'usd',
|
||||
'product_data': {
|
||||
'name': 'OpenHands Monthly',
|
||||
'tax_code': 'txcd_10000000',
|
||||
},
|
||||
'tax_behavior': 'exclusive',
|
||||
'recurring': {'interval': 'month', 'interval_count': 1},
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
|
||||
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
|
||||
STRIPE_WEBHOOK_SECRET = os.environ.get('STRIPE_WEBHOOK_SECRET', None)
|
||||
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')
|
||||
|
||||
SLACK_CLIENT_ID = os.environ.get('SLACK_CLIENT_ID', None)
|
||||
@@ -102,5 +90,5 @@ def get_default_litellm_model():
|
||||
"""
|
||||
if LITELLM_DEFAULT_MODEL:
|
||||
return LITELLM_DEFAULT_MODEL
|
||||
model = USER_SETTINGS_VERSION_TO_MODEL[CURRENT_USER_SETTINGS_VERSION]
|
||||
model = PERSONAL_WORKSPACE_VERSION_TO_MODEL[PERSONAL_WORKSPACE_VERSION]
|
||||
return build_litellm_proxy_model_path(model)
|
||||
|
||||
@@ -44,11 +44,13 @@ class MyProcessor(MaintenanceTaskProcessor):
|
||||
### UserVersionUpgradeProcessor
|
||||
|
||||
Located in `user_version_upgrade_processor.py`, this processor:
|
||||
|
||||
- Handles up to 100 user IDs per task
|
||||
- Upgrades users with `user_version < CURRENT_USER_SETTINGS_VERSION`
|
||||
- Upgrades users with `user_version < ORG_SETTINGS_VERSION`
|
||||
- Uses `SaasSettingsStore.create_default_settings()` for upgrades
|
||||
|
||||
**Usage:**
|
||||
|
||||
```python
|
||||
from server.maintenance_task_processor.user_version_upgrade_processor import UserVersionUpgradeProcessor
|
||||
|
||||
@@ -144,22 +146,26 @@ task = create_maintenance_task(
|
||||
## Best Practices
|
||||
|
||||
### Processor Design
|
||||
|
||||
- Keep tasks short-running (under 1 minute)
|
||||
- Handle errors gracefully and return meaningful error information
|
||||
- Use batch processing for large datasets
|
||||
- Include progress information in the return dict
|
||||
|
||||
### Error Handling
|
||||
|
||||
- Always wrap your processor logic in try-catch blocks
|
||||
- Return structured error information
|
||||
- Log important events for debugging
|
||||
|
||||
### Performance
|
||||
|
||||
- Limit batch sizes to avoid long-running tasks
|
||||
- Use database sessions efficiently
|
||||
- Consider memory usage for large datasets
|
||||
|
||||
### Testing
|
||||
|
||||
- Create unit tests for your processors
|
||||
- Test error conditions
|
||||
- Verify the processor serialization/deserialization works correctly
|
||||
@@ -167,6 +173,7 @@ task = create_maintenance_task(
|
||||
## Database Patterns
|
||||
|
||||
The maintenance task system follows the repository's established patterns:
|
||||
|
||||
- Uses `session_maker()` for database operations
|
||||
- Wraps sync database operations in `call_sync_from_async` for async routes
|
||||
- Follows proper SQLAlchemy query patterns
|
||||
@@ -174,15 +181,18 @@ The maintenance task system follows the repository's established patterns:
|
||||
## Integration with Existing Systems
|
||||
|
||||
### User Management
|
||||
|
||||
- Integrates with the existing `UserSettings` model
|
||||
- Uses the current user versioning system (`CURRENT_USER_SETTINGS_VERSION`)
|
||||
- Uses the current user versioning system (`ORG_SETTINGS_VERSION`)
|
||||
- Maintains compatibility with existing user management workflows
|
||||
|
||||
### Authentication
|
||||
|
||||
- Admin endpoints use the existing SaaS authentication system
|
||||
- Requires users to have `admin = True` in their UserSettings
|
||||
|
||||
### Monitoring
|
||||
|
||||
- Tasks are logged with structured information
|
||||
- Status updates are tracked in the database
|
||||
- Error information is preserved for debugging
|
||||
@@ -206,6 +216,7 @@ The maintenance task system follows the repository's established patterns:
|
||||
## Future Enhancements
|
||||
|
||||
Potential improvements that could be added:
|
||||
|
||||
- Task dependencies and scheduling
|
||||
- Retry mechanisms for failed tasks
|
||||
- Real-time progress updates
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
from server.logger import logger
|
||||
from storage.database import session_maker
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskProcessor
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
|
||||
|
||||
class UserVersionUpgradeProcessor(MaintenanceTaskProcessor):
|
||||
"""
|
||||
Processor for upgrading user settings to the current version.
|
||||
|
||||
This processor takes a list of user IDs and upgrades any users
|
||||
whose user_version is less than CURRENT_USER_SETTINGS_VERSION.
|
||||
"""
|
||||
|
||||
user_ids: List[str]
|
||||
|
||||
async def __call__(self, task: MaintenanceTask) -> dict:
|
||||
"""
|
||||
Process user version upgrades for the specified user IDs.
|
||||
|
||||
Args:
|
||||
task: The maintenance task being processed
|
||||
|
||||
Returns:
|
||||
dict: Results containing successful and failed user IDs
|
||||
"""
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:start',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_count': len(self.user_ids),
|
||||
'current_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
if len(self.user_ids) > 100:
|
||||
raise ValueError(
|
||||
f'Too many user IDs: {len(self.user_ids)}. Maximum is 100.'
|
||||
)
|
||||
|
||||
config = load_openhands_config()
|
||||
|
||||
# Track results
|
||||
successful_upgrades = []
|
||||
failed_upgrades = []
|
||||
users_already_current = []
|
||||
|
||||
# Find users that need upgrading
|
||||
with session_maker() as session:
|
||||
users_to_upgrade = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id.in_(self.user_ids),
|
||||
UserSettings.user_version < CURRENT_USER_SETTINGS_VERSION,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Track users that are already current
|
||||
users_needing_upgrade_ids = {u.keycloak_user_id for u in users_to_upgrade}
|
||||
users_already_current = [
|
||||
uid for uid in self.user_ids if uid not in users_needing_upgrade_ids
|
||||
]
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:found_users',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'users_to_upgrade': len(users_to_upgrade),
|
||||
'users_already_current': len(users_already_current),
|
||||
'total_requested': len(self.user_ids),
|
||||
},
|
||||
)
|
||||
|
||||
# Process each user that needs upgrading
|
||||
for user_settings in users_to_upgrade:
|
||||
user_id = user_settings.keycloak_user_id
|
||||
old_version = user_settings.user_version
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:upgrading_user',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
# Create SaasSettingsStore instance and upgrade
|
||||
settings_store = await SaasSettingsStore.get_instance(config, user_id)
|
||||
await settings_store.create_default_settings(user_settings)
|
||||
|
||||
successful_upgrades.append(
|
||||
{
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:user_upgraded',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_upgrades.append(
|
||||
{'user_id': user_id, 'old_version': old_version, 'error': str(e)}
|
||||
)
|
||||
|
||||
logger.error(
|
||||
'user_version_upgrade_processor:user_upgrade_failed',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
|
||||
# Create result summary
|
||||
result = {
|
||||
'total_users': len(self.user_ids),
|
||||
'users_already_current': users_already_current,
|
||||
'successful_upgrades': successful_upgrades,
|
||||
'failed_upgrades': failed_upgrades,
|
||||
'summary': (
|
||||
f'Processed {len(self.user_ids)} users: '
|
||||
f'{len(successful_upgrades)} upgraded, '
|
||||
f'{len(users_already_current)} already current, '
|
||||
f'{len(failed_upgrades)} errors'
|
||||
),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:completed',
|
||||
extra={'task_id': task.id, 'result': result},
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
|
||||
@@ -36,6 +34,7 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
|
||||
Returns:
|
||||
A tuple containing the default SSE server configuration and a list of MCP stdio server configurations
|
||||
"""
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
if user_id:
|
||||
|
||||
@@ -1,109 +1,97 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from server.config import get_config
|
||||
from server.constants import LITE_LLM_API_KEY, LITE_LLM_API_URL
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.org_store import OrgStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
|
||||
# Helper functions for BYOR API key management
|
||||
async def get_byor_key_from_db(user_id: str) -> str | None:
|
||||
"""Get the BYOR key from the database for a user."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
user_db_settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
if user_db_settings and user_db_settings.llm_api_key_for_byor:
|
||||
return user_db_settings.llm_api_key_for_byor
|
||||
return None
|
||||
def _get_byor_key():
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
if current_org_member.llm_api_key_for_byor:
|
||||
return current_org_member.llm_api_key_for_byor.get_secret_value()
|
||||
|
||||
org = OrgStore.get_org_by_id(current_org_id)
|
||||
if not org:
|
||||
return None
|
||||
return (
|
||||
org.default_llm_api_key_for_byor.get_secret_value()
|
||||
if org.default_llm_api_key_for_byor
|
||||
else None
|
||||
)
|
||||
|
||||
return await call_sync_from_async(_get_byor_key)
|
||||
|
||||
|
||||
async def store_byor_key_in_db(user_id: str, key: str) -> None:
|
||||
"""Store the BYOR key in the database for a user."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
def _update_user_settings():
|
||||
with session_maker() as session:
|
||||
user_db_settings = settings_store.get_user_settings_by_keycloak_id(
|
||||
user_id, session
|
||||
)
|
||||
if user_db_settings:
|
||||
user_db_settings.llm_api_key_for_byor = key
|
||||
session.commit()
|
||||
logger.info(
|
||||
'Successfully stored BYOR key in user settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'User settings not found when trying to store BYOR key',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
current_org_member.llm_api_key_for_byor = key
|
||||
OrgMemberStore.update_org_member(current_org_member)
|
||||
|
||||
await call_sync_from_async(_update_user_settings)
|
||||
|
||||
|
||||
async def generate_byor_key(user_id: str) -> str | None:
|
||||
"""Generate a new BYOR key for a user."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json={
|
||||
key = await LiteLlmManager.generate_key(
|
||||
user_id, user_id, f'BYOR Key - user {user_id}', {'type': 'byor'}
|
||||
)
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'Successfully generated new BYOR key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'metadata': {'type': 'byor'},
|
||||
'key_alias': f'BYOR Key - user {user_id}',
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': key[:10] + '...' if key and len(key) > 10 else key,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json.get('key')
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'Successfully generated new BYOR key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': key[:10] + '...'
|
||||
if key and len(key) > 10
|
||||
else key,
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate BYOR LLM API key - no key in response',
|
||||
extra={'user_id': user_id, 'response_json': response_json},
|
||||
)
|
||||
return None
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate BYOR LLM API key - no key in response',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error generating BYOR key',
|
||||
@@ -114,30 +102,14 @@ async def generate_byor_key(user_id: str) -> str | None:
|
||||
|
||||
async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
|
||||
"""Delete the BYOR key from LiteLLM using the key directly."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
# Delete the key directly using the key value
|
||||
delete_url = f'{LITE_LLM_API_URL}/key/delete'
|
||||
delete_payload = {'keys': [byor_key]}
|
||||
|
||||
delete_response = await client.post(delete_url, json=delete_payload)
|
||||
delete_response.raise_for_status()
|
||||
logger.info(
|
||||
'Successfully deleted BYOR key from LiteLLM',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
await LiteLlmManager.delete_key(byor_key)
|
||||
logger.info(
|
||||
'Successfully deleted BYOR key from LiteLLM',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error deleting BYOR key from LiteLLM',
|
||||
@@ -315,15 +287,6 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
logger.info('Starting BYOR LLM API key refresh', extra={'user_id': user_id})
|
||||
|
||||
try:
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='LiteLLM API configuration not found',
|
||||
)
|
||||
|
||||
# Get the existing BYOR key from the database
|
||||
existing_byor_key = await get_byor_key_from_db(user_id)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Literal, Optional
|
||||
@@ -17,12 +18,12 @@ from server.auth.constants import (
|
||||
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config, sign_token
|
||||
from server.config import sign_token
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from server.routes.event_webhook import _get_session_api_key, _get_user_id
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
@@ -31,7 +32,7 @@ from openhands.server.services.conversation_service import create_provider_token
|
||||
from openhands.server.shared import config
|
||||
from openhands.server.user_auth import get_access_token
|
||||
from openhands.server.user_auth.user_auth import get_user_auth
|
||||
from openhands.utils.posthog_tracker import track_user_signup_completed
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
@@ -83,7 +84,8 @@ def get_cookie_domain(request: Request) -> str | None:
|
||||
# for now just use the full hostname except for staging stacks.
|
||||
return (
|
||||
None
|
||||
if (request.url.hostname or '').endswith('staging.all-hand.dev')
|
||||
if not request.url.hostname
|
||||
or request.url.hostname.endswith('staging.all-hands.dev')
|
||||
else request.url.hostname
|
||||
)
|
||||
|
||||
@@ -147,6 +149,21 @@ async def keycloak_callback(
|
||||
)
|
||||
|
||||
user_id = user_info['sub']
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, user_id)
|
||||
if not user:
|
||||
user = await UserStore.create_user(user_id, user_info)
|
||||
|
||||
if not user:
|
||||
logger.error(f'Failed to authenticate user {user_info["preferred_username"]}')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': f'Failed to authenticate user {user_info["preferred_username"]}'
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f'Logging in user {str(user.id)} in org {user.current_org_id}')
|
||||
|
||||
# default to github IDP for now.
|
||||
# TODO: remove default once Keycloak is updated universally with the new attribute.
|
||||
idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value)
|
||||
@@ -221,15 +238,7 @@ async def keycloak_callback(
|
||||
f'&state={state}'
|
||||
)
|
||||
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
user_settings = settings_store.get_user_settings_by_keycloak_id(user_id)
|
||||
has_accepted_tos = (
|
||||
user_settings is not None and user_settings.accepted_tos is not None
|
||||
)
|
||||
|
||||
has_accepted_tos = user.accepted_tos is not None
|
||||
# If the user hasn't accepted the TOS, redirect to the TOS page
|
||||
if not has_accepted_tos:
|
||||
encoded_redirect_url = quote(redirect_url, safe='')
|
||||
@@ -347,34 +356,20 @@ async def accept_tos(request: Request):
|
||||
redirect_url = body.get('redirect_url', str(request.base_url))
|
||||
|
||||
# Update user settings with TOS acceptance
|
||||
accepted_tos: datetime = datetime.now(timezone.utc)
|
||||
with session_maker() as session:
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user_settings:
|
||||
user_settings.accepted_tos = datetime.now(timezone.utc)
|
||||
session.merge(user_settings)
|
||||
else:
|
||||
# Create user settings if they don't exist
|
||||
user_settings = UserSettings(
|
||||
keycloak_user_id=user_id,
|
||||
accepted_tos=datetime.now(timezone.utc),
|
||||
user_version=0, # This will trigger a migration to the latest version on next load
|
||||
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
|
||||
if not user:
|
||||
session.rollback()
|
||||
logger.error('User for {user_id} not found.')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'User does not exist'},
|
||||
)
|
||||
session.add(user_settings)
|
||||
|
||||
user.accepted_tos = accepted_tos
|
||||
session.commit()
|
||||
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
|
||||
# Track user signup completion in PostHog
|
||||
track_user_signup_completed(
|
||||
user_id=user_id,
|
||||
signup_timestamp=user_settings.accepted_tos.isoformat(),
|
||||
)
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
|
||||
response = JSONResponse(
|
||||
status_code=status.HTTP_200_OK, content={'redirect_url': redirect_url}
|
||||
|
||||
@@ -2,33 +2,23 @@
|
||||
import typing
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
import stripe
|
||||
from dateutil.relativedelta import relativedelta # type: ignore
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.config import get_config
|
||||
from server.constants import (
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
STRIPE_API_KEY,
|
||||
STRIPE_WEBHOOK_SECRET,
|
||||
SUBSCRIPTION_PRICE_DATA,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
from openhands.utils.posthog_tracker import track_credits_purchased
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
billing_router = APIRouter(prefix='/api/billing')
|
||||
@@ -65,23 +55,10 @@ def validate_saas_environment(request: Request) -> None:
|
||||
)
|
||||
|
||||
|
||||
class BillingSessionType(Enum):
|
||||
DIRECT_PAYMENT = 'DIRECT_PAYMENT'
|
||||
MONTHLY_SUBSCRIPTION = 'MONTHLY_SUBSCRIPTION'
|
||||
|
||||
|
||||
class GetCreditsResponse(BaseModel):
|
||||
credits: Decimal | None = None
|
||||
|
||||
|
||||
class SubscriptionAccessResponse(BaseModel):
|
||||
start_at: datetime
|
||||
end_at: datetime
|
||||
created_at: datetime
|
||||
cancelled_at: datetime | None = None
|
||||
stripe_subscription_id: str | None = None
|
||||
|
||||
|
||||
class CreateCheckoutSessionRequest(BaseModel):
|
||||
amount: int
|
||||
|
||||
@@ -112,117 +89,23 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
|
||||
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
|
||||
if not stripe_service.STRIPE_API_KEY:
|
||||
return GetCreditsResponse()
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
user_json = await _get_litellm_user(client, user_id)
|
||||
credits = calculate_credits(user_json['user_info'])
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, user_id)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
user_id, str(user.current_org_id)
|
||||
)
|
||||
# Update to use calculate_credits
|
||||
spend = user_team_info.get('spend', 0)
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
|
||||
credits = max(max_budget - spend, 0)
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
|
||||
|
||||
# Endpoint to retrieve user's current subscription access
|
||||
@billing_router.get('/subscription-access')
|
||||
async def get_subscription_access(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> SubscriptionAccessResponse | None:
|
||||
"""Get details of the currently valid subscription for the user."""
|
||||
with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.first()
|
||||
)
|
||||
if not subscription_access:
|
||||
return None
|
||||
return SubscriptionAccessResponse(
|
||||
start_at=subscription_access.start_at,
|
||||
end_at=subscription_access.end_at,
|
||||
created_at=subscription_access.created_at,
|
||||
cancelled_at=subscription_access.cancelled_at,
|
||||
stripe_subscription_id=subscription_access.stripe_subscription_id,
|
||||
)
|
||||
|
||||
|
||||
# Endpoint to check if a user has entered a payment method into stripe
|
||||
@billing_router.post('/has-payment-method')
|
||||
async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
|
||||
if not user_id:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return await stripe_service.has_payment_method(user_id)
|
||||
|
||||
|
||||
# Endpoint to cancel user's subscription
|
||||
@billing_router.post('/cancel-subscription')
|
||||
async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONResponse:
|
||||
"""Cancel user's active subscription at the end of the current billing period."""
|
||||
if not user_id:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
with session_maker() as session:
|
||||
# Find the user's active subscription
|
||||
now = datetime.now(UTC)
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not already cancelled
|
||||
.first()
|
||||
)
|
||||
|
||||
if not subscription_access:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='No active subscription found',
|
||||
)
|
||||
|
||||
if not subscription_access.stripe_subscription_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot cancel subscription: missing Stripe subscription ID',
|
||||
)
|
||||
|
||||
try:
|
||||
# Cancel the subscription in Stripe at period end
|
||||
await stripe.Subscription.modify_async(
|
||||
subscription_access.stripe_subscription_id, cancel_at_period_end=True
|
||||
)
|
||||
|
||||
# Update local database
|
||||
subscription_access.cancelled_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'subscription_cancelled',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'stripe_subscription_id': subscription_access.stripe_subscription_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
'end_at': subscription_access.end_at,
|
||||
},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
{'status': 'success', 'message': 'Subscription cancelled successfully'}
|
||||
)
|
||||
|
||||
except stripe.StripeError as e:
|
||||
logger.error(
|
||||
'stripe_cancellation_failed',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'stripe_subscription_id': subscription_access.stripe_subscription_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f'Failed to cancel subscription: {str(e)}',
|
||||
)
|
||||
return await stripe_service.has_payment_method_by_user_id(user_id)
|
||||
|
||||
|
||||
# Endpoint to create a new setup intent in stripe
|
||||
@@ -231,16 +114,15 @@ async def create_customer_setup_session(
|
||||
request: Request, user_id: str = Depends(get_user_id)
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
customer=customer_info['customer_id'],
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{request.base_url}?free_credits=success',
|
||||
cancel_url=f'{request.base_url}',
|
||||
)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
|
||||
# Endpoint to create a new Stripe checkout session for credit purchase
|
||||
@@ -252,9 +134,9 @@ async def create_checkout_session(
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
customer=customer_info['customer_id'],
|
||||
line_items=[
|
||||
{
|
||||
'price_data': {
|
||||
@@ -267,7 +149,7 @@ async def create_checkout_session(
|
||||
'tax_behavior': 'exclusive',
|
||||
},
|
||||
'quantity': 1,
|
||||
}
|
||||
},
|
||||
],
|
||||
mode='payment',
|
||||
payment_method_types=['card'],
|
||||
@@ -280,8 +162,9 @@ async def create_checkout_session(
|
||||
logger.info(
|
||||
'created_stripe_checkout_session',
|
||||
extra={
|
||||
'stripe_customer_id': customer_id,
|
||||
'stripe_customer_id': customer_info['customer_id'],
|
||||
'user_id': user_id,
|
||||
'org_id': customer_info['org_id'],
|
||||
'amount': body.amount,
|
||||
'checkout_session_id': checkout_session.id,
|
||||
},
|
||||
@@ -290,105 +173,14 @@ async def create_checkout_session(
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
org_id=customer_info['org_id'],
|
||||
price=body.amount,
|
||||
price_code='NA',
|
||||
billing_session_type=BillingSessionType.DIRECT_PAYMENT.value,
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@billing_router.post('/subscription-checkout-session')
|
||||
async def create_subscription_checkout_session(
|
||||
request: Request,
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
# Prevent duplicate subscriptions for the same user
|
||||
with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
existing_active_subscription = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not cancelled
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_active_subscription:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot create subscription: User already has an active subscription that has not been cancelled',
|
||||
)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
subscription_price_data = SUBSCRIPTION_PRICE_DATA[billing_session_type.value]
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
line_items=[
|
||||
{
|
||||
'price_data': subscription_price_data,
|
||||
'quantity': 1,
|
||||
}
|
||||
],
|
||||
mode='subscription',
|
||||
payment_method_types=['card'],
|
||||
saved_payment_method_options={
|
||||
'payment_method_save': 'enabled',
|
||||
},
|
||||
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
subscription_data={
|
||||
'metadata': {
|
||||
'user_id': user_id,
|
||||
'billing_session_type': billing_session_type.value,
|
||||
}
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
'created_stripe_subscription_checkout_session',
|
||||
extra={
|
||||
'stripe_customer_id': customer_id,
|
||||
'user_id': user_id,
|
||||
'checkout_session_id': checkout_session.id,
|
||||
'billing_session_type': billing_session_type.value,
|
||||
},
|
||||
)
|
||||
with session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
price=subscription_price_data['unit_amount'],
|
||||
price_code='NA',
|
||||
billing_session_type=billing_session_type.value,
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(
|
||||
redirect_url=typing.cast(str, checkout_session.url)
|
||||
)
|
||||
|
||||
|
||||
@billing_router.get('/create-subscription-checkout-session')
|
||||
async def create_subscription_checkout_session_via_get(
|
||||
request: Request,
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> RedirectResponse:
|
||||
"""Create a subscription checkout session using a GET request (For easier copy / paste to URL bar)."""
|
||||
validate_saas_environment(request)
|
||||
|
||||
response = await create_subscription_checkout_session(
|
||||
request, billing_session_type, user_id
|
||||
)
|
||||
return RedirectResponse(response.redirect_url)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
|
||||
# Callback endpoint for successful Stripe payments - updates user credits and billing session status
|
||||
@@ -410,15 +202,6 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
# Any non direct payment (Subscription) is processed in the invoice_payment.paid by the webhook
|
||||
if (
|
||||
billing_session.billing_session_type
|
||||
!= BillingSessionType.DIRECT_PAYMENT.value
|
||||
):
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=success', status_code=302
|
||||
)
|
||||
|
||||
stripe_session = stripe.checkout.Session.retrieve(session_id)
|
||||
if stripe_session.status != 'complete':
|
||||
# Hopefully this never happens - we get a redirect from stripe where the payment is not yet complete
|
||||
@@ -432,45 +215,39 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
# Update max budget in litellm
|
||||
user_json = await _get_litellm_user(client, billing_session.user_id)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
new_max_budget = (
|
||||
(user_json.get('user_info') or {}).get('max_budget') or 0
|
||||
) + add_credits
|
||||
await _upsert_litellm_user(client, billing_session.user_id, new_max_budget)
|
||||
user = await call_sync_from_async(
|
||||
UserStore.get_user_by_id, billing_session.user_id
|
||||
)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
)
|
||||
new_max_budget = max_budget + add_credits
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = amount_subtotal
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
logger.info(
|
||||
'stripe_checkout_success',
|
||||
extra={
|
||||
'amount_subtotal': stripe_session.amount_subtotal,
|
||||
'user_id': billing_session.user_id,
|
||||
'checkout_session_id': billing_session.id,
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
str(user.current_org_id), new_max_budget
|
||||
)
|
||||
|
||||
# Track credits purchased in PostHog
|
||||
try:
|
||||
track_credits_purchased(
|
||||
user_id=billing_session.user_id,
|
||||
amount_usd=amount_subtotal / 100, # Convert cents to dollars
|
||||
credits_added=add_credits,
|
||||
stripe_session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'Failed to track credits purchase: {e}',
|
||||
extra={'user_id': billing_session.user_id, 'error': str(e)},
|
||||
)
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = add_credits
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
logger.info(
|
||||
'stripe_checkout_success',
|
||||
extra={
|
||||
'amount_subtotal': stripe_session.amount_subtotal,
|
||||
'user_id': billing_session.user_id,
|
||||
'org_id': str(user.current_org_id),
|
||||
'checkout_session_id': billing_session.id,
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings/billing?checkout=success', status_code=302
|
||||
@@ -500,206 +277,6 @@ async def cancel_callback(session_id: str, request: Request):
|
||||
session.merge(billing_session)
|
||||
session.commit()
|
||||
|
||||
# Redirect credit purchases to billing screen, subscriptions to LLM settings
|
||||
if (
|
||||
billing_session.billing_session_type
|
||||
== BillingSessionType.DIRECT_PAYMENT.value
|
||||
):
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings/billing?checkout=cancel',
|
||||
status_code=302,
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
# If no billing session found, default to LLM settings (subscription flow)
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=cancel', status_code=302
|
||||
f'{request.base_url}settings/billing?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
|
||||
@billing_router.post('/stripe-webhook')
|
||||
async def stripe_webhook(request: Request) -> JSONResponse:
|
||||
"""Endpoint for stripe webhooks."""
|
||||
payload = await request.body()
|
||||
sig_header = request.headers.get('stripe-signature')
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
raise HTTPException(status_code=400, detail=f'Invalid payload: {e}')
|
||||
except stripe.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
raise HTTPException(status_code=400, detail=f'Invalid signature: {e}')
|
||||
|
||||
# Handle the event
|
||||
logger.info('stripe_webhook_event', extra={'event': event})
|
||||
event_type = event['type']
|
||||
if event_type == 'invoice.paid':
|
||||
invoice = event['data']['object']
|
||||
amount_paid = invoice.amount_paid
|
||||
metadata = invoice.parent.subscription_details.metadata # type: ignore
|
||||
billing_session_type = metadata.billing_session_type
|
||||
assert (
|
||||
amount_paid == SUBSCRIPTION_PRICE_DATA[billing_session_type]['unit_amount']
|
||||
)
|
||||
user_id = metadata.user_id
|
||||
|
||||
start_at = datetime.now(UTC)
|
||||
if billing_session_type == BillingSessionType.MONTHLY_SUBSCRIPTION.value:
|
||||
end_at = start_at + relativedelta(months=1)
|
||||
else:
|
||||
raise ValueError(f'unknown_billing_session_type:{billing_session_type}')
|
||||
|
||||
with session_maker() as session:
|
||||
subscription_access = SubscriptionAccess(
|
||||
status='ACTIVE',
|
||||
user_id=user_id,
|
||||
start_at=start_at,
|
||||
end_at=end_at,
|
||||
amount_paid=amount_paid,
|
||||
stripe_invoice_payment_id=invoice.payment_intent,
|
||||
stripe_subscription_id=invoice.subscription, # Store Stripe subscription ID
|
||||
)
|
||||
session.add(subscription_access)
|
||||
session.commit()
|
||||
elif event_type == 'customer.subscription.updated':
|
||||
subscription = event['data']['object']
|
||||
subscription_id = subscription['id']
|
||||
|
||||
# Handle subscription cancellation
|
||||
if subscription.get('cancel_at_period_end') is True:
|
||||
with session_maker() as session:
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(
|
||||
SubscriptionAccess.stripe_subscription_id == subscription_id
|
||||
)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access and not subscription_access.cancelled_at:
|
||||
subscription_access.cancelled_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'subscription_cancelled_via_webhook',
|
||||
extra={
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'user_id': subscription_access.user_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
},
|
||||
)
|
||||
elif event_type == 'customer.subscription.deleted':
|
||||
subscription = event['data']['object']
|
||||
subscription_id = subscription['id']
|
||||
|
||||
with session_maker() as session:
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.stripe_subscription_id == subscription_id)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access:
|
||||
subscription_access.status = 'DISABLED'
|
||||
subscription_access.updated_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
# Reset user settings to free tier defaults
|
||||
reset_user_to_free_tier_settings(subscription_access.user_id)
|
||||
|
||||
logger.info(
|
||||
'subscription_expired_reset_to_free_tier',
|
||||
extra={
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'user_id': subscription_access.user_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info('stripe_webhook_unhandled_event_type', extra={'type': event_type})
|
||||
|
||||
return JSONResponse({'status': 'success'})
|
||||
|
||||
|
||||
def reset_user_to_free_tier_settings(user_id: str) -> None:
|
||||
"""Reset user settings to free tier defaults when subscription ends."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
user_settings = settings_store.get_user_settings_by_keycloak_id(
|
||||
user_id, session
|
||||
)
|
||||
|
||||
if user_settings:
|
||||
user_settings.llm_model = get_default_litellm_model()
|
||||
user_settings.llm_api_key = None
|
||||
user_settings.llm_api_key_for_byor = None
|
||||
user_settings.llm_base_url = LITE_LLM_API_URL
|
||||
user_settings.max_budget_per_task = None
|
||||
user_settings.confirmation_mode = False
|
||||
user_settings.enable_solvability_analysis = False
|
||||
user_settings.security_analyzer = 'llm'
|
||||
user_settings.agent = 'CodeActAgent'
|
||||
user_settings.language = 'en'
|
||||
user_settings.enable_default_condenser = True
|
||||
user_settings.enable_sound_notifications = False
|
||||
user_settings.enable_proactive_conversation_starters = True
|
||||
user_settings.user_consents_to_analytics = False
|
||||
|
||||
session.merge(user_settings)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'user_settings_reset_to_free_tier',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'reset_timestamp': datetime.now(UTC).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _get_litellm_user(client: httpx.AsyncClient, user_id: str) -> dict:
|
||||
"""Get a user from litellm with the id matching that given.
|
||||
|
||||
If no such user exists, returns a dummy user in the format:
|
||||
`{'user_id': '<USER_ID>', 'user_info': {'spend': 0}, 'keys': [], 'teams': []}`
|
||||
"""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
async def _upsert_litellm_user(
|
||||
client: httpx.AsyncClient, user_id: str, max_budget: float
|
||||
):
|
||||
"""Insert / Update a user in litellm."""
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/update',
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
json={
|
||||
'user_id': user_id,
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -6,7 +6,7 @@ from threading import Thread
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from sqlalchemy import func, select
|
||||
from storage.database import a_session_maker, engine, session_maker
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user import User
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.async_utils import wait_all
|
||||
@@ -127,7 +127,7 @@ def _db_check(delay: int):
|
||||
delay: Number of seconds to hold the database connection
|
||||
"""
|
||||
with session_maker() as session:
|
||||
num_users = session.query(UserSettings).count()
|
||||
num_users = session.query(User).count()
|
||||
time.sleep(delay)
|
||||
logger.info(
|
||||
'check',
|
||||
@@ -155,7 +155,7 @@ async def _a_db_check(delay: int):
|
||||
delay: Number of seconds to hold the database connection
|
||||
"""
|
||||
async with a_session_maker() as a_session:
|
||||
stmt = select(func.count(UserSettings.id))
|
||||
stmt = select(func.count(User.id))
|
||||
num_users = await a_session.execute(stmt)
|
||||
await asyncio.sleep(delay)
|
||||
logger.info(f'a_num_users:{num_users.scalar_one()}')
|
||||
|
||||
@@ -21,7 +21,7 @@ from server.utils.conversation_callback_utils import (
|
||||
update_conversation_stats,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.server.shared import conversation_manager
|
||||
|
||||
@@ -226,12 +226,12 @@ def _parse_conversation_id_and_subpath(path: str) -> Tuple[str, str]:
|
||||
|
||||
def _get_user_id(conversation_id: str) -> str:
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
return conversation_metadata.user_id
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
|
||||
|
||||
async def _get_session_api_key(user_id: str, conversation_id: str) -> str | None:
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy.future import select
|
||||
from storage.database import session_maker
|
||||
from storage.feedback import ConversationFeedback
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.server.shared import file_store
|
||||
@@ -33,10 +33,10 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
|
||||
def _verify_conversation():
|
||||
with session_maker() as session:
|
||||
metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id == conversation_id,
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
@@ -58,7 +59,8 @@ async def github_events(
|
||||
)
|
||||
|
||||
try:
|
||||
payload = await request.body()
|
||||
# Add timeout to prevent hanging on slow/stalled clients
|
||||
payload = await asyncio.wait_for(request.body(), timeout=15.0)
|
||||
verify_github_signature(payload, x_hub_signature_256)
|
||||
|
||||
payload_data = await request.json()
|
||||
@@ -78,6 +80,12 @@ async def github_events(
|
||||
status_code=200,
|
||||
content={'message': 'GitHub events endpoint reached successfully.'},
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning('GitHub webhook request timed out waiting for request body')
|
||||
return JSONResponse(
|
||||
status_code=408,
|
||||
content={'error': 'Request timeout - client took too long to send data.'},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f'Error processing GitHub event: {e}')
|
||||
return JSONResponse(status_code=400, content={'error': 'Invalid payload.'})
|
||||
|
||||
@@ -15,7 +15,6 @@ from integrations.slack.slack_manager import SlackManager
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
from server.auth.constants import (
|
||||
KEYCLOAK_CLIENT_ID,
|
||||
KEYCLOAK_REALM_NAME,
|
||||
@@ -35,9 +34,11 @@ from slack_sdk.web.async_client import AsyncWebClient
|
||||
from storage.database import session_maker
|
||||
from storage.slack_team_store import SlackTeamStore
|
||||
from storage.slack_user import SlackUser
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.shared import config, sio
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
signature_verifier = SignatureVerifier(signing_secret=SLACK_SIGNING_SECRET)
|
||||
slack_router = APIRouter(prefix='/slack')
|
||||
@@ -79,6 +80,14 @@ async def install_callback(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if not config.jwt_secret:
|
||||
logger.error('slack_install_callback_error JWT not configured.')
|
||||
return _html_response(
|
||||
title='Error',
|
||||
description=html.escape('JWT not configured'),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
try:
|
||||
client = AsyncWebClient() # no prepared token needed for this
|
||||
# Complete the installation by calling oauth.v2.access API method
|
||||
@@ -94,16 +103,17 @@ async def install_callback(
|
||||
|
||||
# Create a state variable for keycloak oauth
|
||||
payload = {}
|
||||
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
|
||||
if state:
|
||||
payload = jwt.decode(
|
||||
state, jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
)
|
||||
payload['slack_user_id'] = authed_user.get('id')
|
||||
payload['bot_access_token'] = bot_access_token
|
||||
payload['team_id'] = team_id
|
||||
|
||||
state = jwt.encode(payload, jwt_secret.get_secret_value(), algorithm='HS256')
|
||||
state = jwt.encode(
|
||||
payload, config.jwt_secret.get_secret_value(), algorithm='HS256'
|
||||
)
|
||||
|
||||
# Redirect into keycloak
|
||||
scope = quote('openid email profile offline_access')
|
||||
@@ -149,9 +159,16 @@ async def keycloak_callback(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
|
||||
if not config.jwt_secret:
|
||||
logger.error('problem_retrieving_keycloak_tokens JWT not configured.')
|
||||
return _html_response(
|
||||
title='Error',
|
||||
description=html.escape('JWT not configured'),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
payload: dict[str, str] = jwt.decode(
|
||||
state, jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
)
|
||||
slack_user_id = payload['slack_user_id']
|
||||
bot_access_token = payload['bot_access_token']
|
||||
@@ -180,6 +197,13 @@ async def keycloak_callback(
|
||||
|
||||
user_info = await token_manager.get_user_info(keycloak_access_token)
|
||||
keycloak_user_id = user_info['sub']
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, keycloak_user_id)
|
||||
if not user:
|
||||
return _html_response(
|
||||
title='Failed to authenticate.',
|
||||
description=f'Please re-login into <a href="{HOST_URL}" style="color:#ecedee;text-decoration:underline;">OpenHands Cloud</a>. Then try <a href="https://docs.all-hands.dev/usage/cloud/slack-installation" style="color:#ecedee;text-decoration:underline;">installing the OpenHands Slack App</a> again',
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# These tokens are offline access tokens - store them!
|
||||
await token_manager.store_offline_token(keycloak_user_id, keycloak_refresh_token)
|
||||
@@ -211,6 +235,7 @@ async def keycloak_callback(
|
||||
slack_display_name = slack_user_info.data['user']['profile']['display_name']
|
||||
slack_user = SlackUser(
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
org_id=user.current_org_id,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_display_name=slack_display_name,
|
||||
)
|
||||
@@ -305,7 +330,7 @@ async def on_form_interaction(request: Request, background_tasks: BackgroundTask
|
||||
|
||||
body = await request.body()
|
||||
form = await request.form()
|
||||
payload = json.loads(form.get('payload')) # type: ignore[arg-type]
|
||||
payload = json.loads(form.get('payload'))
|
||||
|
||||
logger.info('slack_on_form_interaction', extra={'payload': payload})
|
||||
|
||||
|
||||
@@ -20,7 +20,10 @@ from server.utils.conversation_callback_utils import (
|
||||
from sqlalchemy import orm
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import LLMConfig, OpenHandsConfig
|
||||
@@ -70,6 +73,11 @@ RUNTIME_CONVERSATION_URL = RUNTIME_URL_PATTERN + (
|
||||
else '/api/conversations/{conversation_id}'
|
||||
)
|
||||
|
||||
RUNTIME_USERNAME = os.getenv('RUNTIME_USERNAME')
|
||||
SU_TO_USER = os.getenv('SU_TO_USER', 'false')
|
||||
truthy = {'1', 'true', 't', 'yes', 'y', 'on'}
|
||||
SU_TO_USER = str(SU_TO_USER.lower() in truthy).lower()
|
||||
|
||||
# Time in seconds before a Redis entry is considered expired if not refreshed
|
||||
_REDIS_ENTRY_TIMEOUT_SECONDS = 300
|
||||
|
||||
@@ -525,16 +533,18 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
"""
|
||||
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation_metadata:
|
||||
if not conversation_metadata_saas:
|
||||
raise ValueError(f'No conversation found {conversation_id}')
|
||||
|
||||
return conversation_metadata.user_id
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
|
||||
async def _get_runtime_status_from_nested_runtime(
|
||||
self, session_api_key: Any | None, nested_url: str, conversation_id: str
|
||||
@@ -772,7 +782,11 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
env_vars['SERVE_FRONTEND'] = '0'
|
||||
env_vars['RUNTIME'] = 'local'
|
||||
# TODO: In the long term we may come up with a more secure strategy for user management within the nested runtime.
|
||||
env_vars['USER'] = 'openhands' if config.run_as_openhands else 'root'
|
||||
env_vars['USER'] = (
|
||||
RUNTIME_USERNAME
|
||||
if RUNTIME_USERNAME
|
||||
else ('openhands' if config.run_as_openhands else 'root')
|
||||
)
|
||||
env_vars['PERMITTED_CORS_ORIGINS'] = ','.join(PERMITTED_CORS_ORIGINS)
|
||||
env_vars['port'] = '60000'
|
||||
# TODO: These values are static in the runtime-api project, but do not get copied into the runtime ENV
|
||||
@@ -789,6 +803,7 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
env_vars['INITIAL_NUM_WARM_SERVERS'] = '1'
|
||||
env_vars['INIT_GIT_IN_EMPTY_WORKSPACE'] = '1'
|
||||
env_vars['ENABLE_V1'] = '0'
|
||||
env_vars['SU_TO_USER'] = SU_TO_USER
|
||||
|
||||
# We need this for LLM traces tracking to identify the source of the LLM calls
|
||||
env_vars['WEB_HOST'] = WEB_HOST
|
||||
@@ -858,9 +873,17 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
with session_maker() as session:
|
||||
# Only include conversations updated in the past week
|
||||
one_week_ago = datetime.now(UTC) - timedelta(days=7)
|
||||
query = session.query(StoredConversationMetadata.conversation_id).filter(
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
StoredConversationMetadata.last_updated_at >= one_week_ago,
|
||||
query = (
|
||||
session.query(StoredConversationMetadata.conversation_id)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
StoredConversationMetadata.last_updated_at >= one_week_ago,
|
||||
)
|
||||
)
|
||||
user_conversation_ids = set(query)
|
||||
return user_conversation_ids
|
||||
@@ -934,11 +957,16 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
if conversation_metadata is None:
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
if conversation_metadata is None or conversation_metadata_saas is None:
|
||||
# Conversation is running in different server
|
||||
return
|
||||
|
||||
user_id = conversation_metadata.user_id
|
||||
user_id = conversation_metadata_saas.user_id
|
||||
|
||||
# Get the id of the next event which is not present
|
||||
events_dir = get_conversation_events_dir(
|
||||
|
||||
@@ -11,7 +11,6 @@ from storage.conversation_callback import (
|
||||
)
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
from openhands.core.schema.agent import AgentState
|
||||
@@ -126,6 +125,12 @@ def update_conversation_metadata(conversation_id: str, content: dict):
|
||||
conversation_id: The conversation ID to update
|
||||
content: The metadata content to update
|
||||
"""
|
||||
|
||||
# Local import fixes the lazy-loading problem
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'update_conversation_metadata',
|
||||
extra={
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
from storage.api_key import ApiKey
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.billing_session_type import BillingSessionType
|
||||
from storage.conversation_callback import CallbackStatus, ConversationCallback
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.experiment_assignment import ExperimentAssignment
|
||||
from storage.feedback import ConversationFeedback, Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_dc_conversation import JiraDcConversation
|
||||
from storage.jira_dc_user import JiraDcUser
|
||||
from storage.jira_dc_workspace import JiraDcWorkspace
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
from storage.linear_conversation import LinearConversation
|
||||
from storage.linear_user import LinearUser
|
||||
from storage.linear_workspace import LinearWorkspace
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.openhands_pr import OpenhandsPR
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.proactive_convos import ProactiveConversation
|
||||
from storage.role import Role
|
||||
from storage.slack_conversation import SlackConversation
|
||||
from storage.slack_team import SlackTeam
|
||||
from storage.slack_user import SlackUser
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.subscription_access_status import SubscriptionAccessStatus
|
||||
from storage.user import User
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
__all__ = [
|
||||
'ApiKey',
|
||||
'AuthTokens',
|
||||
'BillingSession',
|
||||
'BillingSessionType',
|
||||
'CallbackStatus',
|
||||
'ConversationCallback',
|
||||
'ConversationFeedback',
|
||||
'StoredConversationMetadataSaas',
|
||||
'ConversationWork',
|
||||
'ExperimentAssignment',
|
||||
'Feedback',
|
||||
'GithubAppInstallation',
|
||||
'GitlabWebhook',
|
||||
'JiraConversation',
|
||||
'JiraDcConversation',
|
||||
'JiraDcUser',
|
||||
'JiraDcWorkspace',
|
||||
'JiraUser',
|
||||
'JiraWorkspace',
|
||||
'LinearConversation',
|
||||
'LinearUser',
|
||||
'LinearWorkspace',
|
||||
'MaintenanceTask',
|
||||
'MaintenanceTaskStatus',
|
||||
'OpenhandsPR',
|
||||
'Org',
|
||||
'OrgMember',
|
||||
'ProactiveConversation',
|
||||
'Role',
|
||||
'SlackConversation',
|
||||
'SlackTeam',
|
||||
'SlackUser',
|
||||
'StoredConversationMetadata',
|
||||
'StoredOfflineToken',
|
||||
'StoredRepository',
|
||||
'StoredCustomSecrets',
|
||||
'StripeCustomer',
|
||||
'SubscriptionAccess',
|
||||
'SubscriptionAccessStatus',
|
||||
'User',
|
||||
'UserRepositoryMap',
|
||||
'UserSettings',
|
||||
'WebhookStatus',
|
||||
]
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, DateTime, Integer, String, text
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -11,9 +13,13 @@ class ApiKey(Base):
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
key = Column(String(255), nullable=False, unique=True, index=True)
|
||||
user_id = Column(String(255), nullable=False, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
name = Column(String(255), nullable=True)
|
||||
created_at = Column(
|
||||
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
|
||||
)
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
expires_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='api_keys')
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.api_key import ApiKey
|
||||
from storage.database import session_maker
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -36,10 +37,15 @@ class ApiKeyStore:
|
||||
The generated API key
|
||||
"""
|
||||
api_key = self.generate_api_key()
|
||||
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key, user_id=user_id, name=name, expires_at=expires_at
|
||||
key=api_key,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=name,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
session.add(key_record)
|
||||
session.commit()
|
||||
@@ -99,8 +105,15 @@ class ApiKeyStore:
|
||||
|
||||
def list_api_keys(self, user_id: str) -> list[dict]:
|
||||
"""List all API keys for a user."""
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
keys = session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
||||
keys = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
@@ -115,9 +128,14 @@ class ApiKeyStore:
|
||||
]
|
||||
|
||||
def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
keys: list[ApiKey] = (
|
||||
session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
for key in keys:
|
||||
if key.name == 'MCP_API_KEY':
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import DECIMAL, Column, DateTime, Enum, String
|
||||
from sqlalchemy import DECIMAL, Column, DateTime, Enum, ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -11,9 +13,9 @@ class BillingSession(Base): # type: ignore
|
||||
"""
|
||||
|
||||
__tablename__ = 'billing_sessions'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
status = Column(
|
||||
Enum(
|
||||
'in_progress',
|
||||
@@ -24,15 +26,6 @@ class BillingSession(Base): # type: ignore
|
||||
),
|
||||
default='in_progress',
|
||||
)
|
||||
billing_session_type = Column(
|
||||
Enum(
|
||||
'DIRECT_PAYMENT',
|
||||
'MONTHLY_SUBSCRIPTION',
|
||||
name='billing_session_type_enum',
|
||||
),
|
||||
nullable=False,
|
||||
default='DIRECT_PAYMENT',
|
||||
)
|
||||
price = Column(DECIMAL(19, 4), nullable=False)
|
||||
price_code = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
@@ -43,3 +36,6 @@ class BillingSession(Base): # type: ignore
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC), # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='billing_sessions')
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
@@ -7,6 +8,9 @@ from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.util import await_only
|
||||
|
||||
# Check if we're running in a test environment
|
||||
IS_TESTING = 'pytest' in sys.modules
|
||||
|
||||
DB_HOST = os.environ.get('DB_HOST', 'localhost') # for non-GCP environments
|
||||
DB_PORT = os.environ.get('DB_PORT', '5432') # for non-GCP environments
|
||||
DB_USER = os.environ.get('DB_USER', 'postgres')
|
||||
|
||||
114
enterprise/storage/encrypt_utils.py
Normal file
114
enterprise/storage/encrypt_utils.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import binascii
|
||||
import hashlib
|
||||
from base64 import b64decode, b64encode
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from pydantic import SecretStr
|
||||
from server.config import get_config
|
||||
|
||||
_jwt_service = None
|
||||
_fernet = None
|
||||
|
||||
|
||||
def encrypt_model(encrypt_keys: list, model_instance) -> dict:
|
||||
return encrypt_kwargs(encrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def decrypt_model(decrypt_keys: list, model_instance) -> dict:
|
||||
return decrypt_kwargs(decrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def encrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if isinstance(value, dict):
|
||||
encrypt_kwargs(encrypt_keys, value)
|
||||
continue
|
||||
|
||||
if key in encrypt_keys:
|
||||
value = encrypt_value(value)
|
||||
kwargs[key] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def decrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
value = decrypt_value(value)
|
||||
kwargs[key] = value
|
||||
except binascii.Error:
|
||||
pass # Key is in legacy format...
|
||||
return kwargs
|
||||
|
||||
|
||||
def encrypt_value(value: str | SecretStr) -> str:
|
||||
return get_jwt_service().create_jwe_token(
|
||||
{'v': value.get_secret_value() if isinstance(value, SecretStr) else value}
|
||||
)
|
||||
|
||||
|
||||
def decrypt_value(value: str | SecretStr) -> str:
|
||||
token = get_jwt_service().decrypt_jwe_token(
|
||||
value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
)
|
||||
return token['v']
|
||||
|
||||
|
||||
def get_jwt_service():
|
||||
from openhands.app_server.config import get_global_config
|
||||
|
||||
global _jwt_service
|
||||
if _jwt_service is None:
|
||||
jwt_service_injector = get_global_config().jwt
|
||||
assert jwt_service_injector is not None
|
||||
_jwt_service = jwt_service_injector.get_jwt_service()
|
||||
return _jwt_service
|
||||
|
||||
|
||||
def decrypt_legacy_model(decrypt_keys: list, model_instance) -> dict:
|
||||
return decrypt_legacy_kwargs(decrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def decrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
value = decrypt_legacy_value(value)
|
||||
kwargs[key] = value
|
||||
except binascii.Error:
|
||||
pass # Key is in legacy format...
|
||||
except InvalidToken:
|
||||
pass # Key not encrypted...
|
||||
return kwargs
|
||||
|
||||
|
||||
def decrypt_legacy_value(value: str | SecretStr) -> str:
|
||||
if isinstance(value, SecretStr):
|
||||
return (
|
||||
get_fernet().decrypt(b64decode(value.get_secret_value().encode())).decode()
|
||||
)
|
||||
else:
|
||||
return get_fernet().decrypt(b64decode(value.encode())).decode()
|
||||
|
||||
|
||||
def get_fernet():
|
||||
global _fernet
|
||||
if _fernet is None:
|
||||
jwt_secret = get_config().jwt_secret.get_secret_value()
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
_fernet = Fernet(fernet_key)
|
||||
return _fernet
|
||||
|
||||
|
||||
def model_to_kwargs(model_instance):
|
||||
return {
|
||||
column.name: getattr(model_instance, column.name)
|
||||
for column in model_instance.__table__.columns
|
||||
}
|
||||
@@ -1,7 +1,16 @@
|
||||
import sys
|
||||
from enum import IntEnum
|
||||
|
||||
from sqlalchemy import ARRAY, Boolean, Column, DateTime, Integer, String, Text, text
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
text,
|
||||
)
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
|
||||
674
enterprise/storage/lite_llm_manager.py
Normal file
674
enterprise/storage/lite_llm_manager.py
Normal file
@@ -0,0 +1,674 @@
|
||||
"""
|
||||
Store class for managing organizational settings.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
class LiteLlmManager:
|
||||
"""Manage LiteLLM interactions."""
|
||||
|
||||
@staticmethod
|
||||
async def create_entries(
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
oss_settings: Settings,
|
||||
) -> Settings | None:
|
||||
logger.info(
|
||||
'SettingsStore:update_settings_with_litellm_default:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
key = LITE_LLM_API_KEY
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
token_manager = TokenManager()
|
||||
keycloak_user_info = (
|
||||
await token_manager.get_user_info_from_user_id(keycloak_user_id) or {}
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
)
|
||||
|
||||
await LiteLlmManager._create_user(
|
||||
client, keycloak_user_info.get('email'), keycloak_user_id
|
||||
)
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
)
|
||||
|
||||
key = await LiteLlmManager._generate_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
org_id,
|
||||
f'OpenHands Cloud - user {keycloak_user_id}',
|
||||
None,
|
||||
)
|
||||
|
||||
oss_settings.agent = 'CodeActAgent'
|
||||
# Use the model corresponding to the current user settings version
|
||||
oss_settings.llm_model = get_default_litellm_model()
|
||||
oss_settings.llm_api_key = SecretStr(key)
|
||||
oss_settings.llm_base_url = LITE_LLM_API_URL
|
||||
return oss_settings
|
||||
|
||||
@staticmethod
|
||||
async def migrate_entries(
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
user_settings: UserSettings,
|
||||
) -> UserSettings | None:
|
||||
logger.info(
|
||||
'SettingsStore:umigrate_lite_llm_entries:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
user_json = await LiteLlmManager._get_user(client, keycloak_user_id)
|
||||
if not user_json:
|
||||
return None
|
||||
user_info = user_json['user_info']
|
||||
max_budget = user_info.get('max_budget', 0.0)
|
||||
if not max_budget:
|
||||
# if max_budget is None, then we've already migrated the User
|
||||
return None
|
||||
spend = user_info.get('spend', 0.0)
|
||||
credits = max(max_budget - spend, 0.0)
|
||||
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
await LiteLlmManager._update_user(
|
||||
client, keycloak_user_id, max_budget=1000000000.0
|
||||
)
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key:
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key,
|
||||
team_id=org_id,
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key_for_byor:
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key_for_byor,
|
||||
team_id=org_id,
|
||||
)
|
||||
|
||||
user_settings.agent = 'CodeActAgent'
|
||||
# Use the model corresponding to the current user settings version
|
||||
user_settings.llm_model = get_default_litellm_model()
|
||||
user_settings.llm_base_url = LITE_LLM_API_URL
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
async def update_team_and_users_budget(
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
await LiteLlmManager._update_team(client, team_id, None, max_budget)
|
||||
team_info = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_info:
|
||||
return None
|
||||
for membership in team_info.get('team_memberships', []):
|
||||
user_id = membership.get('user_id')
|
||||
if not user_id:
|
||||
continue
|
||||
await LiteLlmManager._update_user_in_team(
|
||||
client, user_id, team_id, max_budget
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _create_team(
|
||||
client: httpx.AsyncClient,
|
||||
team_alias: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/new',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'team_alias': team_alias,
|
||||
'models': [],
|
||||
'max_budget': max_budget,
|
||||
'spend': 0,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
# Team failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if (
|
||||
response.status_code == 400
|
||||
and 'already exists. Please use a different team id' in response.text
|
||||
):
|
||||
# team already exists, so update, then return
|
||||
await LiteLlmManager._update_team(
|
||||
client, team_id, team_alias, max_budget
|
||||
)
|
||||
return
|
||||
logger.error(
|
||||
'error_creating_litellm_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'team_id': team_id,
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_team(client: httpx.AsyncClient, team_id: str) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
"""Get a team from litellm with the id matching that given."""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/team/info?team_id={team_id}',
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
async def _update_team(
|
||||
client: httpx.AsyncClient,
|
||||
team_id: str,
|
||||
team_alias: str | None,
|
||||
max_budget: float | None,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
json_data: dict[str, Any] = {
|
||||
'team_id': team_id,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
}
|
||||
|
||||
if max_budget is not None:
|
||||
json_data['max_budget'] = max_budget
|
||||
|
||||
if team_alias is not None:
|
||||
json_data['team_alias'] = team_alias
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/update',
|
||||
json=json_data,
|
||||
)
|
||||
|
||||
# Team failed to update in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _create_user(
|
||||
client: httpx.AsyncClient,
|
||||
email: str | None,
|
||||
keycloak_user_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': email,
|
||||
'models': [],
|
||||
'user_id': keycloak_user_id,
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': False,
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
if not response.is_success:
|
||||
logger.warning(
|
||||
'duplicate_user_email',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'email': email,
|
||||
},
|
||||
)
|
||||
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': None,
|
||||
'models': [],
|
||||
'user_id': keycloak_user_id,
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': False,
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# User failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if response.status_code == 400 and 'already exists' in response.text:
|
||||
# user already exists, just return
|
||||
return
|
||||
logger.error(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'email': None,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_user(client: httpx.AsyncClient, user_id: str) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
"""Get a user from litellm with the id matching that given."""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
async def _update_user(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
**kwargs,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
payload = {
|
||||
'user_id': keycloak_user_id,
|
||||
}
|
||||
payload.update(kwargs)
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/update',
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _update_key(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
key: str,
|
||||
**kwargs,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
payload = {
|
||||
'key': key,
|
||||
}
|
||||
payload.update(kwargs)
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/update',
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_key',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _delete_user(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [keycloak_user_id]}
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_deleting_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _add_user_to_team(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_add',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'member': {'user_id': keycloak_user_id, 'role': 'user'},
|
||||
'max_budget_in_team': max_budget,
|
||||
},
|
||||
)
|
||||
# Failed to add user to team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_adding_litellm_user_to_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_user_team_info(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
team_info = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_info:
|
||||
return None
|
||||
|
||||
# Filter team_memberships based on team_id and keycloak_user_id
|
||||
user_membership = next(
|
||||
(
|
||||
membership
|
||||
for membership in team_info.get('team_memberships', [])
|
||||
if membership.get('user_id') == keycloak_user_id
|
||||
and membership.get('team_id') == team_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return user_membership
|
||||
|
||||
@staticmethod
|
||||
async def _update_user_in_team(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_update',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'max_budget_in_team': max_budget,
|
||||
},
|
||||
)
|
||||
# Failed to update user in team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_user_in_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _generate_key(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str | None,
|
||||
key_alias: str | None,
|
||||
metadata: dict | None,
|
||||
) -> str | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
json_data: dict[str, Any] = {
|
||||
'user_id': keycloak_user_id,
|
||||
'models': [],
|
||||
}
|
||||
|
||||
if team_id is not None:
|
||||
json_data['team_id'] = team_id
|
||||
|
||||
if key_alias is not None:
|
||||
json_data['key_alias'] = key_alias
|
||||
|
||||
if metadata is not None:
|
||||
json_data['metadata'] = metadata
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json=json_data,
|
||||
)
|
||||
# Failed to generate user key for team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_generate_user_team_key',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
'team_id': team_id,
|
||||
'key_alias': key_alias,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json['key']
|
||||
logger.info(
|
||||
'LiteLlmManager:_lite_llm_generate_user_team_key:key_created',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'team_id': team_id,
|
||||
'key_alias': key_alias,
|
||||
},
|
||||
)
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
async def _get_key_info(
|
||||
client: httpx.AsyncClient,
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
) -> dict | None:
|
||||
from storage.user_store import UserStore
|
||||
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, keycloak_user_id)
|
||||
if not user:
|
||||
return {}
|
||||
|
||||
org_member = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return {}
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/key/info?key={org_member.llm_api_key}'
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key_info = response_json.get('info')
|
||||
if not key_info:
|
||||
return {}
|
||||
return {
|
||||
'key_max_budget': key_info.get('max_budget'),
|
||||
'key_spend': key_info.get('spend'),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _delete_key(
|
||||
client: httpx.AsyncClient,
|
||||
key_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/delete',
|
||||
json={
|
||||
'keys': [key_id],
|
||||
},
|
||||
)
|
||||
# Failed to key...
|
||||
if not response.is_success:
|
||||
if response.status_code == 404:
|
||||
# key doesn't exist, just return
|
||||
return
|
||||
logger.error(
|
||||
'error_deleting_key',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.info(
|
||||
'LiteLlmManager:_delete_key:key_deleted',
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def with_http_client(
|
||||
internal_fn: Callable[..., Awaitable[Any]],
|
||||
) -> Callable[..., Awaitable[Any]]:
|
||||
@functools.wraps(internal_fn)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async with httpx.AsyncClient(
|
||||
headers={'x-goog-api-key': LITE_LLM_API_KEY}
|
||||
) as client:
|
||||
return await internal_fn(client, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
# Public methods with injected client
|
||||
create_team = staticmethod(with_http_client(_create_team))
|
||||
get_team = staticmethod(with_http_client(_get_team))
|
||||
update_team = staticmethod(with_http_client(_update_team))
|
||||
create_user = staticmethod(with_http_client(_create_user))
|
||||
get_user = staticmethod(with_http_client(_get_user))
|
||||
update_user = staticmethod(with_http_client(_update_user))
|
||||
delete_user = staticmethod(with_http_client(_delete_user))
|
||||
add_user_to_team = staticmethod(with_http_client(_add_user_to_team))
|
||||
get_user_team_info = staticmethod(with_http_client(_get_user_team_info))
|
||||
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
|
||||
generate_key = staticmethod(with_http_client(_generate_key))
|
||||
get_key_info = staticmethod(with_http_client(_get_key_info))
|
||||
delete_key = staticmethod(with_http_client(_delete_key))
|
||||
117
enterprise/storage/org.py
Normal file
117
enterprise/storage/org.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import SecretStr
|
||||
from server.constants import DEFAULT_BILLING_MARGIN
|
||||
from sqlalchemy import JSON, UUID, Boolean, Column, Float, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import decrypt_value, encrypt_value
|
||||
|
||||
|
||||
class Org(Base): # type: ignore
|
||||
"""Organization model."""
|
||||
|
||||
__tablename__ = 'org'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
contact_name = Column(String, nullable=True)
|
||||
contact_email = Column(String, nullable=True)
|
||||
agent = Column(String, nullable=True)
|
||||
default_max_iterations = Column(Integer, nullable=True)
|
||||
security_analyzer = Column(String, nullable=True)
|
||||
confirmation_mode = Column(Boolean, nullable=True, default=False)
|
||||
default_llm_model = Column(String, nullable=True)
|
||||
# encrypted column, don't set directly, set without the underscore
|
||||
_default_llm_api_key_for_byor = Column(String, nullable=True)
|
||||
default_llm_base_url = Column(String, nullable=True)
|
||||
remote_runtime_resource_factor = Column(Integer, nullable=True)
|
||||
enable_default_condenser = Column(Boolean, nullable=False, default=True)
|
||||
billing_margin = Column(Float, nullable=True, default=DEFAULT_BILLING_MARGIN)
|
||||
enable_proactive_conversation_starters = Column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
sandbox_base_container_image = Column(String, nullable=True)
|
||||
sandbox_runtime_container_image = Column(String, nullable=True)
|
||||
org_version = Column(Integer, nullable=False, default=0)
|
||||
mcp_config = Column(JSON, nullable=True)
|
||||
# encrypted column, don't set directly, set without the underscore
|
||||
_search_api_key = Column(String, nullable=True)
|
||||
# encrypted column, don't set directly, set without the underscore
|
||||
_sandbox_api_key = Column(String, nullable=True)
|
||||
max_budget_per_task = Column(Float, nullable=True)
|
||||
enable_solvability_analysis = Column(Boolean, nullable=True, default=False)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
conversation_expiration = Column(Integer, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org_members = relationship('OrgMember', back_populates='org')
|
||||
current_users = relationship('User', back_populates='current_org')
|
||||
billing_sessions = relationship('BillingSession', back_populates='org')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='org'
|
||||
)
|
||||
user_secrets = relationship('StoredCustomSecrets', back_populates='org')
|
||||
api_keys = relationship('ApiKey', back_populates='org')
|
||||
slack_conversations = relationship('SlackConversation', back_populates='org')
|
||||
slack_users = relationship('SlackUser', back_populates='org')
|
||||
stripe_customers = relationship('StripeCustomer', back_populates='org')
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Handle known SQLAlchemy columns directly
|
||||
for key in list(kwargs):
|
||||
if hasattr(self.__class__, key):
|
||||
setattr(self, key, kwargs.pop(key))
|
||||
|
||||
# Handle custom property-style fields
|
||||
if 'default_llm_api_key_for_byor' in kwargs:
|
||||
self.default_llm_api_key_for_byor = kwargs.pop(
|
||||
'default_llm_api_key_for_byor'
|
||||
)
|
||||
if 'search_api_key' in kwargs:
|
||||
self.search_api_key = kwargs.pop('search_api_key')
|
||||
if 'sandbox_api_key' in kwargs:
|
||||
self.sandbox_api_key = kwargs.pop('sandbox_api_key')
|
||||
|
||||
if kwargs:
|
||||
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
|
||||
|
||||
@property
|
||||
def default_llm_api_key_for_byor(self) -> SecretStr | None:
|
||||
if self._default_llm_api_key_for_byor:
|
||||
decrypted = decrypt_value(self._default_llm_api_key_for_byor)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@default_llm_api_key_for_byor.setter
|
||||
def default_llm_api_key_for_byor(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._default_llm_api_key_for_byor = encrypt_value(raw) if raw else None
|
||||
|
||||
@property
|
||||
def search_api_key(self) -> SecretStr | None:
|
||||
if self._search_api_key:
|
||||
decrypted = decrypt_value(self._search_api_key)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@search_api_key.setter
|
||||
def search_api_key(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._search_api_key = encrypt_value(raw) if raw else None
|
||||
|
||||
@property
|
||||
def sandbox_api_key(self) -> SecretStr | None:
|
||||
if self._sandbox_api_key:
|
||||
decrypted = decrypt_value(self._sandbox_api_key)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@sandbox_api_key.setter
|
||||
def sandbox_api_key(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._sandbox_api_key = encrypt_value(raw) if raw else None
|
||||
67
enterprise/storage/org_member.py
Normal file
67
enterprise/storage/org_member.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization-Member relationship.
|
||||
"""
|
||||
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy import UUID, Column, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import decrypt_value, encrypt_value
|
||||
|
||||
|
||||
class OrgMember(Base): # type: ignore
|
||||
"""Junction table for organization-member relationships with roles."""
|
||||
|
||||
__tablename__ = 'org_member'
|
||||
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), primary_key=True)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), primary_key=True)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
|
||||
_llm_api_key = Column(String, nullable=False)
|
||||
max_iterations = Column(Integer, nullable=True)
|
||||
llm_model = Column(String, nullable=True)
|
||||
_llm_api_key_for_byor = Column(String, nullable=True)
|
||||
llm_base_url = Column(String, nullable=True)
|
||||
status = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='org_members')
|
||||
user = relationship('User', back_populates='org_members')
|
||||
role = relationship('Role', back_populates='org_members')
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Handle known SQLAlchemy columns directly
|
||||
for key in list(kwargs):
|
||||
if hasattr(self.__class__, key):
|
||||
setattr(self, key, kwargs.pop(key))
|
||||
|
||||
# Handle custom property-style fields
|
||||
if 'llm_api_key' in kwargs:
|
||||
self.llm_api_key = kwargs.pop('llm_api_key')
|
||||
if 'llm_api_key_for_byor' in kwargs:
|
||||
self.llm_api_key_for_byor = kwargs.pop('llm_api_key_for_byor')
|
||||
|
||||
if kwargs:
|
||||
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
|
||||
|
||||
@property
|
||||
def llm_api_key(self) -> SecretStr:
|
||||
decrypted = decrypt_value(self._llm_api_key)
|
||||
return SecretStr(decrypted)
|
||||
|
||||
@llm_api_key.setter
|
||||
def llm_api_key(self, value: str | SecretStr):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._llm_api_key = encrypt_value(raw)
|
||||
|
||||
@property
|
||||
def llm_api_key_for_byor(self) -> SecretStr | None:
|
||||
if self._llm_api_key_for_byor:
|
||||
decrypted = decrypt_value(self._llm_api_key_for_byor)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@llm_api_key_for_byor.setter
|
||||
def llm_api_key_for_byor(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._llm_api_key_for_byor = encrypt_value(raw) if raw else None
|
||||
125
enterprise/storage/org_member_store.py
Normal file
125
enterprise/storage/org_member_store.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Store class for managing organization-member relationships.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from storage.database import session_maker
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
class OrgMemberStore:
|
||||
"""Store for managing organization-member relationships."""
|
||||
|
||||
@staticmethod
|
||||
def add_user_to_org(
|
||||
org_id: UUID,
|
||||
user_id: UUID,
|
||||
role_id: int,
|
||||
llm_api_key: str,
|
||||
status: Optional[str] = None,
|
||||
) -> OrgMember:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
with session_maker() as session:
|
||||
org_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
llm_api_key=llm_api_key,
|
||||
status=status,
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(org_member)
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def get_org_member(org_id: UUID, user_id: int) -> Optional[OrgMember]:
|
||||
"""Get organization-user relationship."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs(user_id: int) -> list[OrgMember]:
|
||||
"""Get all organizations for a user."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.user_id == user_id).all()
|
||||
|
||||
@staticmethod
|
||||
def get_org_members(org_id: UUID) -> list[OrgMember]:
|
||||
"""Get all users in an organization."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.org_id == org_id).all()
|
||||
|
||||
@staticmethod
|
||||
def update_org_member(org_member: OrgMember) -> None:
|
||||
"""Update an organization-member relationship."""
|
||||
with session_maker() as session:
|
||||
session.merge(org_member)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def update_user_role_in_org(
|
||||
org_id: UUID, user_id: int, role_id: int, status: Optional[str] = None
|
||||
) -> Optional[OrgMember]:
|
||||
"""Update user's role in an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
org_member.role_id = role_id
|
||||
if status is not None:
|
||||
org_member.status = status
|
||||
|
||||
session.commit()
|
||||
session.refresh(org_member)
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def remove_user_from_org(org_id: UUID, user_id: int) -> bool:
|
||||
"""Remove a user from an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not org_member:
|
||||
return False
|
||||
|
||||
session.delete(org_member)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: Settings):
|
||||
kwargs = {
|
||||
normalized: getattr(settings, normalized)
|
||||
for c in OrgMember.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {
|
||||
normalized: getattr(user_settings, normalized)
|
||||
for c in OrgMember.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
139
enterprise/storage/org_store.py
Normal file
139
enterprise/storage/org_store.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
Store class for managing organizations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from server.constants import ORG_SETTINGS_VERSION, get_default_litellm_model
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
class OrgStore:
|
||||
"""Store for managing organizations."""
|
||||
|
||||
@staticmethod
|
||||
def create_org(
|
||||
kwargs: dict,
|
||||
) -> Org:
|
||||
"""Create a new organization."""
|
||||
with session_maker() as session:
|
||||
org = Org(**kwargs)
|
||||
org.org_version = ORG_SETTINGS_VERSION
|
||||
org.default_llm_model = get_default_litellm_model()
|
||||
session.add(org)
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_org_by_id(org_id: UUID) -> Org | None:
|
||||
"""Get organization by ID."""
|
||||
with session_maker() as session:
|
||||
return session.query(Org).filter(Org.id == org_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_current_org_from_keycloak_user_id(keycloak_user_id: str) -> Org | None:
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == UUID(keycloak_user_id))
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
logger.warning(f'User not found for ID {keycloak_user_id}')
|
||||
return None
|
||||
org_id = user.current_org_id
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.warning(
|
||||
f'Org not found for ID {org_id} as the current org for user {keycloak_user_id}'
|
||||
)
|
||||
return None
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_org_by_name(name: str) -> Org | None:
|
||||
"""Get organization by name."""
|
||||
with session_maker() as session:
|
||||
return session.query(Org).filter(Org.name == name).first()
|
||||
|
||||
@staticmethod
|
||||
def list_orgs() -> list[Org]:
|
||||
"""List all organizations."""
|
||||
with session_maker() as session:
|
||||
orgs = session.query(Org).all()
|
||||
return orgs
|
||||
|
||||
@staticmethod
|
||||
def update_org(
|
||||
org_id: UUID,
|
||||
kwargs: dict,
|
||||
) -> Optional[Org]:
|
||||
"""Update organization details."""
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
return None
|
||||
|
||||
if 'id' in kwargs:
|
||||
kwargs.pop('id')
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: Settings):
|
||||
kwargs = {}
|
||||
|
||||
for c in Org.__table__.columns:
|
||||
# Normalize for lookup
|
||||
normalized = (
|
||||
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
|
||||
)
|
||||
|
||||
if not hasattr(settings, normalized):
|
||||
continue
|
||||
|
||||
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
|
||||
key = c.name
|
||||
if key.startswith('_'):
|
||||
key = key[1:] # remove only the very first leading underscore
|
||||
|
||||
kwargs[key] = getattr(settings, normalized)
|
||||
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {}
|
||||
|
||||
for c in Org.__table__.columns:
|
||||
# Normalize for lookup
|
||||
normalized = (
|
||||
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
|
||||
)
|
||||
|
||||
if not hasattr(user_settings, normalized):
|
||||
continue
|
||||
|
||||
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
|
||||
key = c.name
|
||||
if key.startswith('_'):
|
||||
key = key[1:] # remove only the very first leading underscore
|
||||
|
||||
kwargs[key] = getattr(user_settings, normalized)
|
||||
|
||||
return kwargs
|
||||
21
enterprise/storage/role.py
Normal file
21
enterprise/storage/role.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
SQLAlchemy model for Role.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class Role(Base): # type: ignore
|
||||
"""Role model for user permissions."""
|
||||
|
||||
__tablename__ = 'role'
|
||||
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
rank = Column(Integer, nullable=False)
|
||||
|
||||
# Relationships
|
||||
users = relationship('User', back_populates='role')
|
||||
org_members = relationship('OrgMember', back_populates='role')
|
||||
40
enterprise/storage/role_store.py
Normal file
40
enterprise/storage/role_store.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Store class for managing roles.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from storage.database import session_maker
|
||||
from storage.role import Role
|
||||
|
||||
|
||||
class RoleStore:
|
||||
"""Store for managing roles."""
|
||||
|
||||
@staticmethod
|
||||
def create_role(name: str, rank: int) -> Role:
|
||||
"""Create a new role."""
|
||||
with session_maker() as session:
|
||||
role = Role(name=name, rank=rank)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
session.refresh(role)
|
||||
return role
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_id(role_id: int) -> Optional[Role]:
|
||||
"""Get role by ID."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.id == role_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_name(name: str) -> Optional[Role]:
|
||||
"""Get role by name."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.name == name).first()
|
||||
|
||||
@staticmethod
|
||||
def list_roles() -> List[Role]:
|
||||
"""List all roles."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).order_by(Role.rank).all()
|
||||
350
enterprise/storage/saas_app_conversation_info_injector.py
Normal file
350
enterprise/storage/saas_app_conversation_info_injector.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy import func, select
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user import User
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_info_service import (
|
||||
AppConversationInfoService,
|
||||
AppConversationInfoServiceInjector,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
AppConversationInfoPage,
|
||||
AppConversationSortOrder,
|
||||
)
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
|
||||
|
||||
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
|
||||
|
||||
async def _secure_select(self):
|
||||
query = (
|
||||
select(StoredConversationMetadata)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
|
||||
async def _secure_select_with_saas_metadata(self):
|
||||
"""Select query that includes SAAS metadata for retrieving user_id."""
|
||||
query = (
|
||||
select(StoredConversationMetadata, StoredConversationMetadataSaas)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
|
||||
async def search_app_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
include_sub_conversations: bool = False,
|
||||
) -> AppConversationInfoPage:
|
||||
"""Search for conversations with user_id from SAAS metadata."""
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
|
||||
# Conditionally exclude sub-conversations based on the parameter
|
||||
if not include_sub_conversations:
|
||||
# Exclude sub-conversations (only include top-level conversations)
|
||||
query = query.where(
|
||||
StoredConversationMetadata.parent_conversation_id.is_(None)
|
||||
)
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
# Add sort order
|
||||
if sort_order == AppConversationSortOrder.CREATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.created_at)
|
||||
elif sort_order == AppConversationSortOrder.CREATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.created_at.desc())
|
||||
elif sort_order == AppConversationSortOrder.UPDATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at)
|
||||
elif sort_order == AppConversationSortOrder.UPDATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
|
||||
elif sort_order == AppConversationSortOrder.TITLE:
|
||||
query = query.order_by(StoredConversationMetadata.title)
|
||||
elif sort_order == AppConversationSortOrder.TITLE_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.title.desc())
|
||||
|
||||
# Apply pagination
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
query = query.offset(offset)
|
||||
except ValueError:
|
||||
# If page_id is not a valid integer, start from beginning
|
||||
offset = 0
|
||||
else:
|
||||
offset = 0
|
||||
|
||||
# Apply limit and get one extra to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(rows) > limit
|
||||
if has_more:
|
||||
rows = rows[:limit]
|
||||
|
||||
items = [
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
for stored_metadata, saas_metadata in rows
|
||||
]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
return AppConversationInfoPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
async def count_app_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count conversations matching the given filters with SAAS metadata."""
|
||||
query = (
|
||||
select(func.count(StoredConversationMetadata.conversation_id))
|
||||
.select_from(
|
||||
StoredConversationMetadata.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
# Apply user filtering
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
count = result.scalar()
|
||||
return count or 0
|
||||
|
||||
def _apply_filters_with_saas_metadata(
|
||||
self,
|
||||
query,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
):
|
||||
"""Apply filters to query that includes SAAS metadata."""
|
||||
# Apply the same filters as the base class
|
||||
conditions = []
|
||||
if title__contains is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.title.like(f'%{title__contains}%')
|
||||
)
|
||||
|
||||
if created_at__gte is not None:
|
||||
conditions.append(StoredConversationMetadata.created_at >= created_at__gte)
|
||||
|
||||
if created_at__lt is not None:
|
||||
conditions.append(StoredConversationMetadata.created_at < created_at__lt)
|
||||
|
||||
if updated_at__gte is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.last_updated_at >= updated_at__gte
|
||||
)
|
||||
|
||||
if updated_at__lt is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.last_updated_at < updated_at__lt
|
||||
)
|
||||
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
return query
|
||||
|
||||
async def get_app_conversation_info(
|
||||
self, conversation_id: UUID
|
||||
) -> AppConversationInfo | None:
|
||||
"""Get conversation info with user_id from SAAS metadata."""
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
query = query.where(
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
result_set = await self.db_session.execute(query)
|
||||
result = result_set.first()
|
||||
if result:
|
||||
stored_metadata, saas_metadata = result
|
||||
return self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
return None
|
||||
|
||||
async def batch_get_app_conversation_info(
|
||||
self, conversation_ids: list[UUID]
|
||||
) -> list[AppConversationInfo | None]:
|
||||
"""Batch get conversation info with user_id from SAAS metadata."""
|
||||
conversation_id_strs = [
|
||||
str(conversation_id) for conversation_id in conversation_ids
|
||||
]
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
query = query.where(
|
||||
StoredConversationMetadata.conversation_id.in_(conversation_id_strs)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Create a mapping of conversation_id to (metadata, saas_metadata)
|
||||
info_by_id = {}
|
||||
for stored_metadata, saas_metadata in rows:
|
||||
info_by_id[stored_metadata.conversation_id] = (
|
||||
stored_metadata,
|
||||
saas_metadata,
|
||||
)
|
||||
|
||||
results: list[AppConversationInfo | None] = []
|
||||
for conversation_id in conversation_id_strs:
|
||||
if conversation_id in info_by_id:
|
||||
stored_metadata, saas_metadata = info_by_id[conversation_id]
|
||||
results.append(
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
)
|
||||
else:
|
||||
results.append(None)
|
||||
|
||||
return results
|
||||
|
||||
async def save_app_conversation_info(
|
||||
self, info: AppConversationInfo
|
||||
) -> AppConversationInfo:
|
||||
"""Save conversation info and create/update SAAS metadata with user_id and org_id."""
|
||||
# Save the base conversation metadata
|
||||
await super().save_app_conversation_info(info)
|
||||
|
||||
# Get current user_id for SAAS metadata
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
# Convert string user_id to UUID
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
user_query = select(User).where(User.id == user_id_uuid)
|
||||
result = await self.db_session.execute(user_query)
|
||||
user = result.scalar_one_or_none()
|
||||
assert user
|
||||
|
||||
# Check if SAAS metadata already exists
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(info.id)
|
||||
)
|
||||
result = await self.db_session.execute(saas_query)
|
||||
existing_saas_metadata = result.scalar_one_or_none()
|
||||
assert existing_saas_metadata is None or (
|
||||
existing_saas_metadata.user_id == user_id_uuid
|
||||
and existing_saas_metadata.org_id == user.current_org_id
|
||||
)
|
||||
|
||||
if not existing_saas_metadata:
|
||||
# Create new SAAS metadata
|
||||
# Set org_id to user_id as specified in requirements
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=str(info.id),
|
||||
user_id=user_id_uuid,
|
||||
org_id=user.current_org_id,
|
||||
)
|
||||
self.db_session.add(saas_metadata)
|
||||
|
||||
await self.db_session.commit()
|
||||
|
||||
return info
|
||||
|
||||
def _to_info_with_user_id(
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
saas_metadata: StoredConversationMetadataSaas,
|
||||
) -> AppConversationInfo:
|
||||
"""Convert stored metadata to AppConversationInfo with user_id from SAAS metadata."""
|
||||
# Use the base _to_info method to get the basic info
|
||||
info = self._to_info(stored)
|
||||
|
||||
# Override the created_by_user_id with the user_id from SAAS metadata
|
||||
info.created_by_user_id = (
|
||||
str(saas_metadata.user_id) if saas_metadata.user_id else None
|
||||
)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
class SaasAppConversationInfoServiceInjector(AppConversationInfoServiceInjector):
|
||||
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[AppConversationInfoService, None]:
|
||||
from openhands.app_server.config import (
|
||||
get_db_session,
|
||||
get_user_context,
|
||||
)
|
||||
|
||||
async with (
|
||||
get_user_context(state, request) as user_context,
|
||||
get_db_session(state, request) as db_session,
|
||||
):
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=db_session, user_context=user_context
|
||||
)
|
||||
yield service
|
||||
@@ -4,10 +4,15 @@ import dataclasses
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.integrations.provider import ProviderType
|
||||
@@ -29,20 +34,37 @@ logger = logging.getLogger(__name__)
|
||||
class SaasConversationStore(ConversationStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
org_id: UUID | None = None # will be fetched automatically
|
||||
|
||||
def __init__(self, user_id: str, session_maker: sessionmaker):
|
||||
self.user_id = user_id
|
||||
self.session_maker = session_maker
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
self.org_id = user.current_org_id if user else None
|
||||
|
||||
def _select_by_id(self, session, conversation_id: str):
|
||||
return (
|
||||
# Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org
|
||||
query = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.user_id == self.user_id)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(StoredConversationMetadataSaas.user_id == UUID(self.user_id))
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.filter(StoredConversationMetadata.conversation_version == 'V0')
|
||||
)
|
||||
|
||||
if self.org_id is not None:
|
||||
query = query.filter(StoredConversationMetadataSaas.org_id == self.org_id)
|
||||
|
||||
return query
|
||||
|
||||
def _to_external_model(self, conversation_metadata: StoredConversationMetadata):
|
||||
kwargs = {
|
||||
c.name: getattr(conversation_metadata, c.name)
|
||||
for c in StoredConversationMetadata.__table__.columns
|
||||
if c.name != 'github_user_id' # Skip github_user_id field
|
||||
}
|
||||
# TODO: I'm not sure why the timezone is not set on the dates coming back out of the db
|
||||
kwargs['created_at'] = kwargs['created_at'].replace(tzinfo=UTC)
|
||||
@@ -53,6 +75,8 @@ class SaasConversationStore(ConversationStore):
|
||||
# Convert string to ProviderType enum
|
||||
kwargs['git_provider'] = ProviderType(kwargs['git_provider'])
|
||||
|
||||
kwargs['user_id'] = self.user_id
|
||||
|
||||
# Remove V1 attributes
|
||||
kwargs.pop('max_budget_per_task', None)
|
||||
kwargs.pop('cache_read_tokens', None)
|
||||
@@ -66,7 +90,10 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
async def save_metadata(self, metadata: ConversationMetadata):
|
||||
kwargs = dataclasses.asdict(metadata)
|
||||
kwargs['user_id'] = self.user_id
|
||||
|
||||
# Remove user_id and org_id from kwargs since they're no longer in StoredConversationMetadata
|
||||
kwargs.pop('user_id', None)
|
||||
kwargs.pop('org_id', None)
|
||||
|
||||
# Convert ProviderType enum to string for storage
|
||||
if kwargs.get('git_provider') is not None:
|
||||
@@ -80,7 +107,41 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
def _save_metadata():
|
||||
with self.session_maker() as session:
|
||||
# Save the main conversation metadata
|
||||
session.merge(stored_metadata)
|
||||
|
||||
# Create or update the SaaS metadata record
|
||||
saas_metadata = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== stored_metadata.conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not saas_metadata:
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=stored_metadata.conversation_id,
|
||||
user_id=UUID(self.user_id),
|
||||
org_id=self.org_id,
|
||||
)
|
||||
session.add(saas_metadata)
|
||||
else:
|
||||
# Validate
|
||||
expected_user_id = UUID(self.user_id)
|
||||
expected_org_id = self.org_id
|
||||
|
||||
if saas_metadata.user_id != expected_user_id:
|
||||
raise ValueError(
|
||||
f'Existing user_id ({saas_metadata.user_id}) does not match expected value ({expected_user_id}).'
|
||||
)
|
||||
|
||||
if expected_org_id and saas_metadata.org_id != expected_org_id:
|
||||
raise ValueError(
|
||||
f'Existing org_id ({saas_metadata.org_id}) does not match expected value ({expected_org_id}).'
|
||||
)
|
||||
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_save_metadata)
|
||||
@@ -100,8 +161,29 @@ class SaasConversationStore(ConversationStore):
|
||||
async def delete_metadata(self, conversation_id: str) -> None:
|
||||
def _delete_metadata():
|
||||
with self.session_maker() as session:
|
||||
self._select_by_id(session, conversation_id).delete()
|
||||
session.commit()
|
||||
saas_record = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == UUID(self.user_id),
|
||||
StoredConversationMetadataSaas.org_id == self.org_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if saas_record:
|
||||
# Delete both records, but only if the SaaS one exists
|
||||
session.query(StoredConversationMetadata).filter(
|
||||
StoredConversationMetadata.conversation_id == conversation_id,
|
||||
).delete()
|
||||
|
||||
session.delete(saas_record)
|
||||
|
||||
session.commit()
|
||||
else:
|
||||
# No SaaS record found → skip deleting main metadata
|
||||
session.rollback()
|
||||
|
||||
await call_sync_from_async(_delete_metadata)
|
||||
|
||||
@@ -124,7 +206,15 @@ class SaasConversationStore(ConversationStore):
|
||||
with self.session_maker() as session:
|
||||
conversations = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.user_id == self.user_id)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.user_id == UUID(self.user_id)
|
||||
)
|
||||
.filter(StoredConversationMetadataSaas.org_id == self.org_id)
|
||||
.filter(StoredConversationMetadata.conversation_version == 'V0')
|
||||
.order_by(StoredConversationMetadata.created_at.desc())
|
||||
.offset(offset)
|
||||
|
||||
@@ -8,11 +8,13 @@ from cryptography.fernet import Fernet
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -24,14 +26,17 @@ class SaasSecretsStore(SecretsStore):
|
||||
async def load(self) -> Secrets | None:
|
||||
if not self.user_id:
|
||||
return None
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Fetch all secrets for the given user ID
|
||||
settings = (
|
||||
session.query(StoredCustomSecrets)
|
||||
.filter(StoredCustomSecrets.keycloak_user_id == self.user_id)
|
||||
.all()
|
||||
query = session.query(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
)
|
||||
if org_id is not None:
|
||||
query = query.filter(StoredCustomSecrets.org_id == org_id)
|
||||
settings = query.all()
|
||||
|
||||
if not settings:
|
||||
return Secrets()
|
||||
@@ -48,6 +53,8 @@ class SaasSecretsStore(SecretsStore):
|
||||
return Secrets(custom_secrets=kwargs) # type: ignore[arg-type]
|
||||
|
||||
async def store(self, item: Secrets):
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
# Incoming secrets are always the most updated ones
|
||||
# Delete all existing records and override with incoming ones
|
||||
@@ -76,6 +83,7 @@ class SaasSecretsStore(SecretsStore):
|
||||
for secret_name, secret_value, description in secret_tuples:
|
||||
new_secret = StoredCustomSecrets(
|
||||
keycloak_user_id=self.user_id,
|
||||
org_id=org_id,
|
||||
secret_name=secret_name,
|
||||
secret_value=secret_value,
|
||||
description=description,
|
||||
|
||||
@@ -2,45 +2,37 @@ from __future__ import annotations
|
||||
|
||||
import binascii
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from base64 import b64decode, b64encode
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
from cryptography.fernet import Fernet
|
||||
from integrations import stripe_service
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
CURRENT_USER_SETTINGS_VERSION,
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
REQUIRE_PAYMENT,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import joinedload, sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_store import OrgStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.storage.settings.settings_store import SettingsStore as OssSettingsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
|
||||
@dataclass
|
||||
class SaasSettingsStore(SettingsStore):
|
||||
class SaasSettingsStore(OssSettingsStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
ENCRYPT_VALUES = ['llm_api_key', 'llm_api_key_for_byor', 'search_api_key']
|
||||
|
||||
def get_user_settings_by_keycloak_id(
|
||||
def _get_user_settings_by_keycloak_id(
|
||||
self, keycloak_user_id: str, session=None
|
||||
) -> UserSettings | None:
|
||||
"""
|
||||
@@ -76,246 +68,104 @@ class SaasSettingsStore(SettingsStore):
|
||||
return _get_settings()
|
||||
|
||||
async def load(self) -> Settings | None:
|
||||
if not self.user_id:
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
if not user:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
return None
|
||||
with self.session_maker() as session:
|
||||
settings = self.get_user_settings_by_keycloak_id(self.user_id, session)
|
||||
|
||||
if not settings or settings.user_version != CURRENT_USER_SETTINGS_VERSION:
|
||||
logger.info(
|
||||
'saas_settings_store:load:triggering_migration',
|
||||
extra={'user_id': self.user_id},
|
||||
org_id = user.current_org_id
|
||||
org_member: OrgMember = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
)
|
||||
return None
|
||||
kwargs = {
|
||||
**{
|
||||
normalized: getattr(org, c.name)
|
||||
for c in Org.__table__.columns
|
||||
if (
|
||||
normalized := c.name.removeprefix('_default_')
|
||||
.removeprefix('default_')
|
||||
.lstrip('_')
|
||||
)
|
||||
return await self.create_default_settings(settings)
|
||||
kwargs = {
|
||||
c.name: getattr(settings, c.name)
|
||||
for c in UserSettings.__table__.columns
|
||||
if c.name in Settings.model_fields
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
in Settings.model_fields
|
||||
},
|
||||
**{
|
||||
normalized: getattr(user, c.name)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) in Settings.model_fields
|
||||
},
|
||||
}
|
||||
kwargs['llm_api_key'] = org_member.llm_api_key
|
||||
if org_member.max_iterations:
|
||||
kwargs['max_iterations'] = org_member.max_iterations
|
||||
if org_member.llm_model:
|
||||
kwargs['llm_model'] = org_member.llm_model
|
||||
if org_member.llm_api_key_for_byor:
|
||||
kwargs['llm_api_key_for_byor'] = org_member.llm_api_key_for_byor
|
||||
if org_member.llm_base_url:
|
||||
kwargs['llm_base_url'] = org_member.llm_base_url
|
||||
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
|
||||
async def store(self, item: Settings):
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if item and self._is_openhands_provider(item):
|
||||
await self._ensure_openhands_api_key(item)
|
||||
|
||||
# Call the static store method from SettingsStore
|
||||
with self.session_maker() as session:
|
||||
existing = None
|
||||
kwargs = {}
|
||||
if item:
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
self._encrypt_kwargs(kwargs)
|
||||
# First check if we have an existing entry in the new table
|
||||
existing = self.get_user_settings_by_keycloak_id(self.user_id, session)
|
||||
|
||||
kwargs = {
|
||||
key: value
|
||||
for key, value in kwargs.items()
|
||||
if key in UserSettings.__table__.columns
|
||||
}
|
||||
if existing:
|
||||
# Update existing entry
|
||||
for key, value in kwargs.items():
|
||||
setattr(existing, key, value)
|
||||
existing.user_version = CURRENT_USER_SETTINGS_VERSION
|
||||
session.merge(existing)
|
||||
else:
|
||||
kwargs['keycloak_user_id'] = self.user_id
|
||||
kwargs['user_version'] = CURRENT_USER_SETTINGS_VERSION
|
||||
kwargs.pop('secrets_store', None) # Don't save secrets_store to db
|
||||
settings = UserSettings(**kwargs)
|
||||
session.add(settings)
|
||||
session.commit()
|
||||
|
||||
async def create_default_settings(self, user_settings: UserSettings | None):
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:start',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
# You must log in before you get default settings
|
||||
if not self.user_id:
|
||||
return None
|
||||
|
||||
# Only users that have specified a payment method get default settings
|
||||
if REQUIRE_PAYMENT and not await stripe_service.has_payment_method(
|
||||
self.user_id
|
||||
):
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:no_payment',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
settings: Settings | None = None
|
||||
if user_settings is None:
|
||||
settings = Settings(
|
||||
language='en',
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
elif isinstance(user_settings, UserSettings):
|
||||
# Convert UserSettings (SQLAlchemy model) to Settings (Pydantic model)
|
||||
kwargs = {
|
||||
c.name: getattr(user_settings, c.name)
|
||||
for c in UserSettings.__table__.columns
|
||||
if c.name in Settings.model_fields
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
|
||||
if settings:
|
||||
settings = await self.update_settings_with_litellm_default(settings)
|
||||
if settings is None:
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:litellm_update_failed',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
await self.store(settings)
|
||||
return settings
|
||||
|
||||
async def load_legacy_file_store_settings(self, github_user_id: str):
|
||||
if not github_user_id:
|
||||
return None
|
||||
|
||||
file_store = get_file_store(self.config.file_store, self.config.file_store_path)
|
||||
path = f'users/github/{github_user_id}/settings.json'
|
||||
|
||||
try:
|
||||
json_str = await call_sync_from_async(file_store.read, path)
|
||||
logger.info(
|
||||
'saas_settings_store:load_legacy_file_store_settings:found',
|
||||
extra={'github_user_id': github_user_id},
|
||||
)
|
||||
kwargs = json.loads(json_str)
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'saas_settings_store:load_legacy_file_store_settings:error',
|
||||
extra={'github_user_id': github_user_id, 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
async def update_settings_with_litellm_default(
|
||||
self, settings: Settings
|
||||
) -> Settings | None:
|
||||
logger.info(
|
||||
'saas_settings_store:update_settings_with_litellm_default:start',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
key = LITE_LLM_API_KEY
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
token_manager = TokenManager()
|
||||
keycloak_user_info = (
|
||||
await token_manager.get_user_info_from_user_id(self.user_id) or {}
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
# Get the previous max budget to prevent accidental loss
|
||||
# In Litellm a get always succeeds, regardless of whether the user actually exists
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={self.user_id}'
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
user_info = response_json.get('user_info') or {}
|
||||
logger.info(
|
||||
f'creating_litellm_user: {self.user_id}; prev_max_budget: {user_info.get("max_budget")}; prev_metadata: {user_info.get("metadata")}'
|
||||
)
|
||||
max_budget = user_info.get('max_budget') or DEFAULT_INITIAL_BUDGET
|
||||
spend = user_info.get('spend') or 0
|
||||
if not item:
|
||||
return None
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(self.user_id))
|
||||
).first()
|
||||
|
||||
if not user:
|
||||
# Check if we need to migrate from user_settings
|
||||
user_settings = None
|
||||
with session_maker() as session:
|
||||
user_settings = self.get_user_settings_by_keycloak_id(
|
||||
user_settings = self._get_user_settings_by_keycloak_id(
|
||||
self.user_id, session
|
||||
)
|
||||
# In upgrade to V4, we no longer use billing margin, but instead apply this directly
|
||||
# in litellm. The default billing marign was 2 before this (hence the magic numbers below)
|
||||
if (
|
||||
user_settings
|
||||
and user_settings.user_version < 4
|
||||
and user_settings.billing_margin
|
||||
and user_settings.billing_margin != 1.0
|
||||
):
|
||||
billing_margin = user_settings.billing_margin
|
||||
logger.info(
|
||||
'user_settings_v4_budget_upgrade',
|
||||
extra={
|
||||
'max_budget': max_budget,
|
||||
'billing_margin': billing_margin,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
max_budget *= billing_margin
|
||||
spend *= billing_margin
|
||||
user_settings.billing_margin = 1.0
|
||||
session.commit()
|
||||
|
||||
email = keycloak_user_info.get('email')
|
||||
|
||||
# We explicitly delete here to guard against odd inherited settings on upgrade.
|
||||
# We don't care if this fails with a 404
|
||||
await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [self.user_id]}
|
||||
)
|
||||
|
||||
# Create the new litellm user
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, email, max_budget, spend
|
||||
)
|
||||
if not response.is_success:
|
||||
logger.warning(
|
||||
'duplicate_user_email',
|
||||
extra={'user_id': self.user_id, 'email': email},
|
||||
)
|
||||
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, None, max_budget, spend
|
||||
)
|
||||
|
||||
# User failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [self.user_id],
|
||||
'email': email,
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
if user_settings:
|
||||
user = await UserStore.migrate_user(self.user_id, user_settings)
|
||||
else:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
return None
|
||||
|
||||
response_json = response.json()
|
||||
key = response_json['key']
|
||||
|
||||
logger.info(
|
||||
'saas_settings_store:update_settings_with_litellm_default:user_created',
|
||||
extra={'user_id': self.user_id},
|
||||
org_id = user.current_org_id
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if self._is_openhands_provider(item):
|
||||
await self._ensure_openhands_api_key(item, str(org_id))
|
||||
org_member = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
)
|
||||
return None
|
||||
|
||||
settings.agent = 'CodeActAgent'
|
||||
# Use the model corresponding to the current user settings version
|
||||
settings.llm_model = get_default_litellm_model()
|
||||
settings.llm_api_key = SecretStr(key)
|
||||
settings.llm_base_url = LITE_LLM_API_URL
|
||||
return settings
|
||||
for model in (user, org, org_member):
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(model, key):
|
||||
setattr(model, key, value)
|
||||
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(
|
||||
@@ -326,6 +176,9 @@ class SaasSettingsStore(SettingsStore):
|
||||
logger.debug(f'saas_settings_store.get_instance::{user_id}')
|
||||
return SaasSettingsStore(user_id, session_maker, config)
|
||||
|
||||
def _should_encrypt(self, key):
|
||||
return key in self.ENCRYPT_VALUES
|
||||
|
||||
def _decrypt_kwargs(self, kwargs: dict):
|
||||
fernet = self._fernet()
|
||||
for key, value in kwargs.items():
|
||||
@@ -369,21 +222,24 @@ class SaasSettingsStore(SettingsStore):
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
return Fernet(fernet_key)
|
||||
|
||||
def _should_encrypt(self, key: str) -> bool:
|
||||
return key in ('llm_api_key', 'llm_api_key_for_byor', 'search_api_key')
|
||||
|
||||
def _is_openhands_provider(self, item: Settings) -> bool:
|
||||
"""Check if the settings use the OpenHands provider."""
|
||||
return bool(item.llm_model and item.llm_model.startswith('openhands/'))
|
||||
|
||||
async def _ensure_openhands_api_key(self, item: Settings) -> None:
|
||||
async def _ensure_openhands_api_key(self, item: Settings, org_id: str) -> None:
|
||||
"""Generate and set the OpenHands API key for the given settings.
|
||||
|
||||
First checks if an existing key with the OpenHands alias exists,
|
||||
and reuses it if found. Otherwise, generates a new key.
|
||||
"""
|
||||
# Generate new key if none exists
|
||||
generated_key = await self._generate_openhands_key()
|
||||
generated_key = await LiteLlmManager.generate_key(
|
||||
self.user_id,
|
||||
org_id,
|
||||
f'Openhands Provider Key - user {self.user_id}',
|
||||
{'type': 'openhands'},
|
||||
)
|
||||
|
||||
if generated_key:
|
||||
item.llm_api_key = SecretStr(generated_key)
|
||||
logger.info(
|
||||
@@ -395,78 +251,3 @@ class SaasSettingsStore(SettingsStore):
|
||||
'saas_settings_store:store:failed_to_generate_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
|
||||
async def _create_user_in_lite_llm(
|
||||
self, client: httpx.AsyncClient, email: str | None, max_budget: int, spend: int
|
||||
):
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': email,
|
||||
'models': [],
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
'user_id': str(self.user_id),
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': True,
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': CURRENT_USER_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
'key_alias': f'OpenHands Cloud - user {self.user_id}',
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
||||
async def _generate_openhands_key(self) -> str | None:
|
||||
"""Generate a new OpenHands provider key for a user."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'saas_settings_store:_generate_openhands_key:litellm_config_not_found',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json={
|
||||
'user_id': self.user_id,
|
||||
'metadata': {'type': 'openhands'},
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json.get('key')
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'saas_settings_store:_generate_openhands_key:success',
|
||||
extra={
|
||||
'user_id': self.user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': (
|
||||
key[:10] + '...' if key and len(key) > 10 else key
|
||||
),
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'saas_settings_store:_generate_openhands_key:no_key_in_response',
|
||||
extra={'user_id': self.user_id, 'response_json': response_json},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'saas_settings_store:_generate_openhands_key:error',
|
||||
extra={'user_id': self.user_id, 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from sqlalchemy import Column, ForeignKey, Identity, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -8,4 +10,8 @@ class SlackConversation(Base): # type: ignore
|
||||
conversation_id = Column(String, nullable=False, index=True)
|
||||
channel_id = Column(String, nullable=False)
|
||||
keycloak_user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
parent_id = Column(String, nullable=True, index=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='slack_conversations')
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, DateTime, Identity, Integer, String, text
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Identity, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -6,6 +8,7 @@ class SlackUser(Base): # type: ignore
|
||||
__tablename__ = 'slack_users'
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
keycloak_user_id = Column(String, nullable=False, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
slack_user_id = Column(String, nullable=False, index=True)
|
||||
slack_display_name = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
@@ -13,3 +16,6 @@ class SlackUser(Base): # type: ignore
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='slack_users')
|
||||
|
||||
@@ -1,8 +1,22 @@
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata as _StoredConversationMetadata,
|
||||
)
|
||||
def _get_stored_conversation_metadata():
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata as _StoredConversationMetadata,
|
||||
)
|
||||
|
||||
StoredConversationMetadata = _StoredConversationMetadata
|
||||
return _StoredConversationMetadata
|
||||
|
||||
|
||||
# Lazy import to avoid circular dependency
|
||||
StoredConversationMetadata = None
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
global StoredConversationMetadata
|
||||
if name == 'StoredConversationMetadata':
|
||||
if StoredConversationMetadata is None:
|
||||
StoredConversationMetadata = _get_stored_conversation_metadata()
|
||||
return StoredConversationMetadata
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
|
||||
|
||||
__all__ = ['StoredConversationMetadata']
|
||||
|
||||
28
enterprise/storage/stored_conversation_metadata_saas.py
Normal file
28
enterprise/storage/stored_conversation_metadata_saas.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
SQLAlchemy model for ConversationMetadataSaas.
|
||||
|
||||
This model stores the SaaS-specific metadata for conversations,
|
||||
containing only the conversation_id, user_id, and org_id.
|
||||
"""
|
||||
|
||||
from sqlalchemy import UUID as SQL_UUID
|
||||
from sqlalchemy import Column, ForeignKey, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class StoredConversationMetadataSaas(Base): # type: ignore
|
||||
"""SaaS conversation metadata model containing user and org associations."""
|
||||
|
||||
__tablename__ = 'conversation_metadata_saas'
|
||||
|
||||
conversation_id = Column(String, primary_key=True)
|
||||
user_id = Column(SQL_UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
|
||||
org_id = Column(SQL_UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
|
||||
|
||||
# Relationships
|
||||
user = relationship('User', back_populates='stored_conversation_metadata_saas')
|
||||
org = relationship('Org', back_populates='stored_conversation_metadata_saas')
|
||||
|
||||
|
||||
__all__ = ['StoredConversationMetadataSaas']
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from sqlalchemy import Column, ForeignKey, Identity, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -6,6 +8,10 @@ class StoredCustomSecrets(Base): # type: ignore
|
||||
__tablename__ = 'custom_secrets'
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
keycloak_user_id = Column(String, nullable=True, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
secret_name = Column(String, nullable=False)
|
||||
secret_value = Column(String, nullable=False)
|
||||
description = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='user_secrets')
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, DateTime, Integer, String, text
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -13,6 +15,7 @@ class StripeCustomer(Base): # type: ignore
|
||||
__tablename__ = 'stripe_customers'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
keycloak_user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
stripe_customer_id = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
|
||||
@@ -23,3 +26,6 @@ class StripeCustomer(Base): # type: ignore
|
||||
onupdate=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='stripe_customers')
|
||||
|
||||
41
enterprise/storage/user.py
Normal file
41
enterprise/storage/user.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
SQLAlchemy model for User.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import (
|
||||
UUID,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class User(Base): # type: ignore
|
||||
"""User model with organizational relationships."""
|
||||
|
||||
__tablename__ = 'user'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
current_org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=True)
|
||||
accepted_tos = Column(DateTime, nullable=True)
|
||||
enable_sound_notifications = Column(Boolean, nullable=True)
|
||||
language = Column(String, nullable=True)
|
||||
user_consents_to_analytics = Column(Boolean, nullable=True)
|
||||
email = Column(String, nullable=True)
|
||||
email_verified = Column(Boolean, nullable=True)
|
||||
|
||||
# Relationships
|
||||
role = relationship('Role', back_populates='users')
|
||||
org_members = relationship('OrgMember', back_populates='user')
|
||||
current_org = relationship('Org', back_populates='current_users')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='user'
|
||||
)
|
||||
@@ -38,3 +38,7 @@ class UserSettings(Base): # type: ignore
|
||||
email_verified = Column(Boolean, nullable=True)
|
||||
git_user_name = Column(String, nullable=True)
|
||||
git_user_email = Column(String, nullable=True)
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
already_migrated = Column(
|
||||
Boolean, nullable=True, default=False
|
||||
) # False = not migrated, True = migrated
|
||||
|
||||
332
enterprise/storage/user_store.py
Normal file
332
enterprise/storage/user_store.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""
|
||||
Store class for managing users.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from server.logger import logger
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.encrypt_utils import decrypt_legacy_model
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
|
||||
|
||||
|
||||
class UserStore:
|
||||
"""Store for managing users."""
|
||||
|
||||
@staticmethod
|
||||
async def create_user(
|
||||
user_id: str,
|
||||
user_info: dict,
|
||||
role_id: Optional[int] = None,
|
||||
) -> User | None:
|
||||
"""Create a new user."""
|
||||
with session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_name=user_info['preferred_username'],
|
||||
contact_email=user_info['email'],
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
settings = await UserStore.create_default_settings(
|
||||
org_id=str(org.id), user_id=user_id
|
||||
)
|
||||
|
||||
if not settings:
|
||||
return None
|
||||
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
org_kwargs = OrgStore.get_kwargs_from_settings(settings)
|
||||
for key, value in org_kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
user_kwargs = UserStore.get_kwargs_from_settings(settings)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
role_id=role_id,
|
||||
**user_kwargs,
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
role = RoleStore.get_role_by_name('owner')
|
||||
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
org_member_kwargs = OrgMemberStore.get_kwargs_from_settings(settings)
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id, # owner of your own org.
|
||||
status='active',
|
||||
**org_member_kwargs,
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def migrate_user(
|
||||
user_id: str,
|
||||
user_settings: UserSettings,
|
||||
user_info: dict,
|
||||
) -> User:
|
||||
if not user_id or not user_settings:
|
||||
return None
|
||||
|
||||
kwargs = decrypt_legacy_model(
|
||||
[
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
],
|
||||
user_settings,
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
with session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_name=user_info['username'],
|
||||
contact_email=user_info['email'],
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
await LiteLlmManager.migrate_entries(
|
||||
str(org.id),
|
||||
user_id,
|
||||
decrypted_user_settings,
|
||||
)
|
||||
|
||||
# avoids circular reference. This migrate method is temprorary until all users are migrated.
|
||||
from integrations.stripe_service import migrate_customer
|
||||
|
||||
await migrate_customer(session, user_id, org)
|
||||
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
org_kwargs = OrgStore.get_kwargs_from_user_settings(decrypted_user_settings)
|
||||
org_kwargs.pop('id', None)
|
||||
for key, value in org_kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
user_kwargs = UserStore.get_kwargs_from_user_settings(
|
||||
decrypted_user_settings
|
||||
)
|
||||
user_kwargs.pop('id', None)
|
||||
user = User(
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
role_id=None,
|
||||
**user_kwargs,
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
role = RoleStore.get_role_by_name('owner')
|
||||
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
org_member_kwargs = OrgMemberStore.get_kwargs_from_user_settings(
|
||||
decrypted_user_settings
|
||||
)
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id, # owner of your own org.
|
||||
status='active',
|
||||
**org_member_kwargs,
|
||||
)
|
||||
session.add(org_member)
|
||||
|
||||
# Mark the old user_settings as migrated instead of deleting
|
||||
user_settings.already_migrated = True
|
||||
session.merge(user_settings)
|
||||
session.flush()
|
||||
|
||||
# need to migrate conversation metadata
|
||||
session.execute(
|
||||
text("""
|
||||
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
|
||||
SELECT
|
||||
conversation_id,
|
||||
:user_id,
|
||||
:user_id
|
||||
FROM conversation_metadata
|
||||
WHERE user_id = :user_id
|
||||
"""),
|
||||
{'user_id': user_id},
|
||||
)
|
||||
|
||||
# Update org_id for tables that had org_id added
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
|
||||
# Update stripe_customers
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE stripe_customers SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update slack_users
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE slack_users SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update slack_conversation
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE slack_conversation SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update api_keys
|
||||
session.execute(
|
||||
text('UPDATE api_keys SET org_id = :org_id WHERE user_id = :user_id'),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update custom_secrets
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE custom_secrets SET org_id = :org_id WHERE keycloak_user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
# Update billing_sessions
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE billing_sessions SET org_id = :org_id WHERE user_id = :user_id'
|
||||
),
|
||||
{'org_id': user_uuid, 'user_id': user_uuid},
|
||||
)
|
||||
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID."""
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
# Check if we need to migrate from user_settings
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user_settings:
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
token_manager = TokenManager()
|
||||
user_info = call_async_from_sync(
|
||||
token_manager.get_user_info_from_user_id,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
)
|
||||
user = call_async_from_sync(
|
||||
UserStore.migrate_user,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
user_settings,
|
||||
user_info,
|
||||
)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def list_users() -> list[User]:
|
||||
"""List all users."""
|
||||
with session_maker() as session:
|
||||
return session.query(User).all()
|
||||
|
||||
# Prevent circular imports
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
@staticmethod
|
||||
async def create_default_settings(
|
||||
org_id: str, user_id: str
|
||||
) -> Optional['Settings']:
|
||||
logger.info(
|
||||
'UserStore:create_default_settings:start',
|
||||
extra={'org_id': org_id, 'user_id': user_id},
|
||||
)
|
||||
# You must log in before you get default settings
|
||||
if not org_id:
|
||||
return None
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
settings = Settings(language='en', enable_proactive_conversation_starters=True)
|
||||
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
settings = await LiteLlmManager.create_entries(org_id, user_id, settings)
|
||||
if not settings:
|
||||
logger.info(
|
||||
'UserStore:create_default_settings:litellm_create_failed',
|
||||
extra={'org_id': org_id},
|
||||
)
|
||||
return None
|
||||
|
||||
return settings
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: 'Settings'):
|
||||
kwargs = {
|
||||
normalized: getattr(settings, normalized)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_user_settings(user_settings: UserSettings):
|
||||
kwargs = {
|
||||
normalized: getattr(user_settings, normalized)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
@@ -3,7 +3,7 @@ from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from integrations.types import GitLabResourceType
|
||||
from integrations.utils import GITLAB_WEBHOOK_URL
|
||||
from server.constants import WEB_HOST
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
from storage.gitlab_webhook_store import GitlabWebhookStore
|
||||
|
||||
@@ -11,6 +11,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.service_types import GitService
|
||||
|
||||
GITLAB_WEBHOOK_URL = f'https://{WEB_HOST}/integration/gitlab/events'
|
||||
CHUNK_SIZE = 100
|
||||
WEBHOOK_NAME = 'OpenHands Resolver'
|
||||
SCOPES: list[str] = [
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
from server.maintenance_task_processor.user_version_upgrade_processor import (
|
||||
UserVersionUpgradeProcessor,
|
||||
)
|
||||
from server.constants import ORG_SETTINGS_VERSION
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
@@ -14,11 +13,20 @@ from storage.billing_session import BillingSession
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.feedback import Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user import User
|
||||
|
||||
# Import the actual StoredConversationMetadata from OpenHands core
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -67,7 +75,6 @@ def add_minimal_fixtures(session_maker):
|
||||
session.add(
|
||||
StoredConversationMetadata(
|
||||
conversation_id='mock-conversation-id',
|
||||
user_id='mock-user-id',
|
||||
created_at=datetime.fromisoformat('2025-03-07'),
|
||||
last_updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
accumulated_cost=5.25,
|
||||
@@ -76,6 +83,13 @@ def add_minimal_fixtures(session_maker):
|
||||
total_tokens=750,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StoredConversationMetadataSaas(
|
||||
conversation_id='mock-conversation-id',
|
||||
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StoredOfflineToken(
|
||||
user_id='mock-user-id',
|
||||
@@ -84,7 +98,38 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
|
||||
session.add(
|
||||
Org(
|
||||
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
name='mock-org',
|
||||
org_version=ORG_SETTINGS_VERSION,
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
Role(
|
||||
id=1,
|
||||
name='admin',
|
||||
rank=1,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
User(
|
||||
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
current_org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
user_consents_to_analytics=True,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
OrgMember(
|
||||
org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
user_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
role_id=1,
|
||||
llm_api_key='mock-api-key',
|
||||
status='active',
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id='mock-user-id',
|
||||
@@ -93,13 +138,6 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-10'),
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
UserSettings(
|
||||
keycloak_user_id='mock-user-id',
|
||||
user_consents_to_analytics=True,
|
||||
user_version=CURRENT_USER_SETTINGS_VERSION,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
ConversationWork(
|
||||
conversation_id='mock-conversation-id',
|
||||
@@ -108,17 +146,6 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
maintenance_task = MaintenanceTask(
|
||||
status=MaintenanceTaskStatus.PENDING,
|
||||
)
|
||||
maintenance_task.set_processor(
|
||||
UserVersionUpgradeProcessor(
|
||||
user_ids=['mock-user-id'],
|
||||
created_at=datetime.fromisoformat('2025-03-07'),
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
session.add(maintenance_task)
|
||||
session.commit()
|
||||
|
||||
|
||||
|
||||
133
enterprise/tests/unit/integrations/test_resolver_context.py
Normal file
133
enterprise/tests/unit/integrations/test_resolver_context.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Test for ResolverUserContext get_secrets conversion logic.
|
||||
|
||||
This test focuses on testing the actual ResolverUserContext implementation.
|
||||
"""
|
||||
|
||||
from types import MappingProxyType
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from enterprise.integrations.resolver_context import ResolverUserContext
|
||||
|
||||
# Import the real classes we want to test
|
||||
from openhands.integrations.provider import CustomSecret
|
||||
|
||||
# Import the SDK types we need for testing
|
||||
from openhands.sdk.conversation.secret_source import SecretSource, StaticSecret
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_saas_user_auth():
|
||||
"""Mock SaasUserAuth for testing."""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def resolver_context(mock_saas_user_auth):
|
||||
"""Create a ResolverUserContext instance for testing."""
|
||||
return ResolverUserContext(saas_user_auth=mock_saas_user_auth)
|
||||
|
||||
|
||||
def create_custom_secret(value: str, description: str = 'Test secret') -> CustomSecret:
|
||||
"""Helper to create CustomSecret instances."""
|
||||
return CustomSecret(secret=SecretStr(value), description=description)
|
||||
|
||||
|
||||
def create_secrets(custom_secrets_dict: dict[str, CustomSecret]) -> Secrets:
|
||||
"""Helper to create Secrets instances."""
|
||||
return Secrets(custom_secrets=MappingProxyType(custom_secrets_dict))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secrets_converts_custom_to_static(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_secrets correctly converts CustomSecret objects to StaticSecret objects."""
|
||||
# Arrange
|
||||
secrets = create_secrets(
|
||||
{
|
||||
'TEST_SECRET_1': create_custom_secret('secret_value_1'),
|
||||
'TEST_SECRET_2': create_custom_secret('secret_value_2'),
|
||||
}
|
||||
)
|
||||
mock_saas_user_auth.get_secrets.return_value = secrets
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_secrets()
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(secret, StaticSecret) for secret in result.values())
|
||||
assert result['TEST_SECRET_1'].value.get_secret_value() == 'secret_value_1'
|
||||
assert result['TEST_SECRET_2'].value.get_secret_value() == 'secret_value_2'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secrets_with_special_characters(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that secret values with special characters are preserved during conversion."""
|
||||
# Arrange
|
||||
special_value = 'very_secret_password_123!@#$%^&*()'
|
||||
secrets = create_secrets({'SPECIAL_SECRET': create_custom_secret(special_value)})
|
||||
mock_saas_user_auth.get_secrets.return_value = secrets
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_secrets()
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert isinstance(result['SPECIAL_SECRET'], StaticSecret)
|
||||
assert result['SPECIAL_SECRET'].value.get_secret_value() == special_value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'secrets_input,expected_result',
|
||||
[
|
||||
(None, {}), # No secrets available
|
||||
(create_secrets({}), {}), # Empty custom secrets
|
||||
],
|
||||
)
|
||||
async def test_get_secrets_empty_cases(
|
||||
resolver_context, mock_saas_user_auth, secrets_input, expected_result
|
||||
):
|
||||
"""Test that get_secrets handles empty cases correctly."""
|
||||
# Arrange
|
||||
mock_saas_user_auth.get_secrets.return_value = secrets_input
|
||||
|
||||
# Act
|
||||
result = await resolver_context.get_secrets()
|
||||
|
||||
# Assert
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_static_secret_is_valid_secret_source():
|
||||
"""Test that StaticSecret is a valid SecretSource for SDK validation."""
|
||||
# Arrange & Act
|
||||
static_secret = StaticSecret(value='test_secret_123')
|
||||
|
||||
# Assert
|
||||
assert isinstance(static_secret, StaticSecret)
|
||||
assert isinstance(static_secret, SecretSource)
|
||||
assert static_secret.value.get_secret_value() == 'test_secret_123'
|
||||
|
||||
|
||||
def test_custom_to_static_conversion():
|
||||
"""Test the complete conversion flow from CustomSecret to StaticSecret."""
|
||||
# Arrange
|
||||
secret_value = 'conversion_test_secret'
|
||||
custom_secret = create_custom_secret(secret_value, 'Conversion test')
|
||||
|
||||
# Act - simulate the conversion logic from the actual method
|
||||
extracted_value = custom_secret.secret.get_secret_value()
|
||||
static_secret = StaticSecret(value=extracted_value)
|
||||
|
||||
# Assert
|
||||
assert isinstance(static_secret, StaticSecret)
|
||||
assert isinstance(static_secret, SecretSource)
|
||||
assert static_secret.value.get_secret_value() == secret_value
|
||||
@@ -6,22 +6,32 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import BackgroundTasks, HTTPException, Request, status
|
||||
from server.routes.event_webhook import (
|
||||
BatchMethod,
|
||||
BatchOperation,
|
||||
_get_session_api_key,
|
||||
_get_user_id,
|
||||
_parse_conversation_id_and_subpath,
|
||||
_process_batch_operations_background,
|
||||
on_batch_write,
|
||||
on_delete,
|
||||
on_write,
|
||||
|
||||
# Import the actual StoredConversationMetadata from OpenHands core
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from server.utils.conversation_callback_utils import (
|
||||
process_event,
|
||||
update_conversation_metadata,
|
||||
)
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
# Mock the lazy import to return the actual class
|
||||
with patch(
|
||||
'storage.stored_conversation_metadata.StoredConversationMetadata',
|
||||
StoredConversationMetadata,
|
||||
):
|
||||
from server.routes.event_webhook import (
|
||||
BatchMethod,
|
||||
BatchOperation,
|
||||
_get_session_api_key,
|
||||
_get_user_id,
|
||||
_parse_conversation_id_and_subpath,
|
||||
_process_batch_operations_background,
|
||||
on_batch_write,
|
||||
on_delete,
|
||||
on_write,
|
||||
)
|
||||
from server.utils.conversation_callback_utils import (
|
||||
process_event,
|
||||
update_conversation_metadata,
|
||||
)
|
||||
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
|
||||
@@ -82,7 +92,7 @@ class TestGetUserId:
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
user_id = _get_user_id('mock-conversation-id')
|
||||
assert user_id == 'mock-user-id'
|
||||
assert user_id == '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
|
||||
def test_get_user_id_conversation_not_found(self, session_maker):
|
||||
"""Test getting user ID when conversation doesn't exist."""
|
||||
@@ -105,10 +115,12 @@ class TestGetSessionApiKey:
|
||||
return_value=[mock_agent_loop_info]
|
||||
)
|
||||
|
||||
api_key = await _get_session_api_key('user-123', 'conv-456')
|
||||
api_key = await _get_session_api_key(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', 'conv-456'
|
||||
)
|
||||
assert api_key == 'test-api-key'
|
||||
mock_manager.get_agent_loop_info.assert_called_once_with(
|
||||
'user-123', filter_to_sids={'conv-456'}
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', filter_to_sids={'conv-456'}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -118,7 +130,9 @@ class TestGetSessionApiKey:
|
||||
mock_manager.get_agent_loop_info = AsyncMock(return_value=[])
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
await _get_session_api_key('user-123', 'conv-456')
|
||||
await _get_session_api_key(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', 'conv-456'
|
||||
)
|
||||
|
||||
|
||||
class TestProcessEvent:
|
||||
@@ -142,10 +156,15 @@ class TestProcessEvent:
|
||||
mock_event = MagicMock()
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
|
||||
mock_file_store.write.assert_called_once_with(
|
||||
'users/user-123/conversations/conv-456/events/event-1.json',
|
||||
'users/5594c7b6-f959-4b81-92e9-b09c206f5081/conversations/conv-456/events/event-1.json',
|
||||
json.dumps(content),
|
||||
)
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
@@ -177,14 +196,19 @@ class TestProcessEvent:
|
||||
)
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
|
||||
mock_file_store.write.assert_called_once()
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
mock_invoke_callbacks.assert_called_once_with('conv-456', mock_event)
|
||||
mock_update_working_seconds.assert_called_once()
|
||||
mock_event_store_class.assert_called_once_with(
|
||||
'conv-456', mock_file_store, 'user-123'
|
||||
'conv-456', mock_file_store, '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -212,7 +236,12 @@ class TestProcessEvent:
|
||||
mock_event.agent_state = 'running' # Set RUNNING state to skip the update
|
||||
mock_event_from_dict.return_value = mock_event
|
||||
|
||||
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
|
||||
await process_event(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
'conv-456',
|
||||
'events/event-1.json',
|
||||
content,
|
||||
)
|
||||
|
||||
mock_file_store.write.assert_called_once()
|
||||
mock_event_from_dict.assert_called_once_with(content)
|
||||
@@ -236,10 +265,13 @@ class TestUpdateConversationMetadata:
|
||||
'total_tokens': 1500,
|
||||
}
|
||||
|
||||
with patch(
|
||||
'server.utils.conversation_callback_utils.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
# Import the module and patch the session_maker at the module level
|
||||
import server.utils.conversation_callback_utils as callback_utils
|
||||
|
||||
original_session_maker = callback_utils.session_maker
|
||||
|
||||
try:
|
||||
callback_utils.session_maker = session_maker_with_minimal_fixtures
|
||||
update_conversation_metadata('mock-conversation-id', content)
|
||||
|
||||
# Verify the conversation was updated
|
||||
@@ -257,6 +289,9 @@ class TestUpdateConversationMetadata:
|
||||
assert conversation.completion_tokens == 500
|
||||
assert conversation.total_tokens == 1500
|
||||
assert isinstance(conversation.last_updated_at, datetime)
|
||||
finally:
|
||||
# Restore the original session_maker
|
||||
callback_utils.session_maker = original_session_maker
|
||||
|
||||
def test_update_conversation_metadata_partial_fields(
|
||||
self, session_maker_with_minimal_fixtures
|
||||
@@ -264,10 +299,13 @@ class TestUpdateConversationMetadata:
|
||||
"""Test updating conversation metadata with only some fields."""
|
||||
content = {'accumulated_cost': 15.75, 'prompt_tokens': 2000}
|
||||
|
||||
with patch(
|
||||
'server.utils.conversation_callback_utils.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
# Import the module and patch the session_maker at the module level
|
||||
import server.utils.conversation_callback_utils as callback_utils
|
||||
|
||||
original_session_maker = callback_utils.session_maker
|
||||
|
||||
try:
|
||||
callback_utils.session_maker = session_maker_with_minimal_fixtures
|
||||
update_conversation_metadata('mock-conversation-id', content)
|
||||
|
||||
# Verify only specified fields were updated, others remain unchanged
|
||||
@@ -285,6 +323,9 @@ class TestUpdateConversationMetadata:
|
||||
# These should remain as original values from fixtures
|
||||
assert conversation.completion_tokens == 250
|
||||
assert conversation.total_tokens == 750
|
||||
finally:
|
||||
# Restore the original session_maker
|
||||
callback_utils.session_maker = original_session_maker
|
||||
|
||||
def test_update_conversation_metadata_empty_content(
|
||||
self, session_maker_with_minimal_fixtures
|
||||
@@ -292,10 +333,13 @@ class TestUpdateConversationMetadata:
|
||||
"""Test updating conversation metadata with empty content."""
|
||||
content: dict[str, float] = {}
|
||||
|
||||
with patch(
|
||||
'server.utils.conversation_callback_utils.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
# Import the module and patch the session_maker at the module level
|
||||
import server.utils.conversation_callback_utils as callback_utils
|
||||
|
||||
original_session_maker = callback_utils.session_maker
|
||||
|
||||
try:
|
||||
callback_utils.session_maker = session_maker_with_minimal_fixtures
|
||||
update_conversation_metadata('mock-conversation-id', content)
|
||||
|
||||
# Verify only last_updated_at was changed
|
||||
@@ -314,6 +358,9 @@ class TestUpdateConversationMetadata:
|
||||
assert conversation.completion_tokens == 250
|
||||
assert conversation.total_tokens == 750
|
||||
assert isinstance(conversation.last_updated_at, datetime)
|
||||
finally:
|
||||
# Restore the original session_maker
|
||||
callback_utils.session_maker = original_session_maker
|
||||
|
||||
|
||||
class TestOnDelete:
|
||||
@@ -344,24 +391,31 @@ class TestOnWrite:
|
||||
content = {'accumulated_cost': 20.0}
|
||||
mock_request.json.return_value = content
|
||||
|
||||
with patch(
|
||||
'server.routes.event_webhook.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
), patch(
|
||||
'server.utils.conversation_callback_utils.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
), patch(
|
||||
'server.routes.event_webhook._get_session_api_key'
|
||||
) as mock_get_api_key:
|
||||
mock_get_api_key.return_value = 'correct-api-key'
|
||||
# Import the module and patch the session_maker at the module level
|
||||
import server.utils.conversation_callback_utils as callback_utils
|
||||
|
||||
result = await on_write(
|
||||
'sessions/mock-conversation-id/metadata.json',
|
||||
mock_request,
|
||||
'correct-api-key',
|
||||
)
|
||||
original_session_maker = callback_utils.session_maker
|
||||
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
try:
|
||||
with patch(
|
||||
'server.routes.event_webhook.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
), patch(
|
||||
'server.routes.event_webhook._get_session_api_key'
|
||||
) as mock_get_api_key:
|
||||
mock_get_api_key.return_value = 'correct-api-key'
|
||||
callback_utils.session_maker = session_maker_with_minimal_fixtures
|
||||
|
||||
result = await on_write(
|
||||
'sessions/mock-conversation-id/metadata.json',
|
||||
mock_request,
|
||||
'correct-api-key',
|
||||
)
|
||||
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
finally:
|
||||
# Restore the original session_maker
|
||||
callback_utils.session_maker = original_session_maker
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_write_events_success(
|
||||
@@ -569,31 +623,38 @@ class TestProcessBatchOperationsBackground:
|
||||
)
|
||||
]
|
||||
|
||||
with patch(
|
||||
'server.routes.event_webhook.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
), patch(
|
||||
'server.routes.event_webhook._get_session_api_key'
|
||||
) as mock_get_api_key, patch(
|
||||
'server.utils.conversation_callback_utils.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
mock_get_api_key.return_value = 'correct-api-key'
|
||||
# Import the module and patch the session_maker at the module level
|
||||
import server.utils.conversation_callback_utils as callback_utils
|
||||
|
||||
# Should not raise any exceptions
|
||||
await _process_batch_operations_background(batch_ops, 'correct-api-key')
|
||||
original_session_maker = callback_utils.session_maker
|
||||
|
||||
# Verify the conversation metadata was updated
|
||||
with session_maker_with_minimal_fixtures() as session:
|
||||
conversation = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id
|
||||
== 'mock-conversation-id'
|
||||
try:
|
||||
with patch(
|
||||
'server.routes.event_webhook.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
), patch(
|
||||
'server.routes.event_webhook._get_session_api_key'
|
||||
) as mock_get_api_key:
|
||||
mock_get_api_key.return_value = 'correct-api-key'
|
||||
callback_utils.session_maker = session_maker_with_minimal_fixtures
|
||||
|
||||
# Should not raise any exceptions
|
||||
await _process_batch_operations_background(batch_ops, 'correct-api-key')
|
||||
|
||||
# Verify the conversation metadata was updated
|
||||
with session_maker_with_minimal_fixtures() as session:
|
||||
conversation = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id
|
||||
== 'mock-conversation-id'
|
||||
)
|
||||
.first()
|
||||
)
|
||||
.first()
|
||||
)
|
||||
assert conversation.accumulated_cost == 15.0
|
||||
assert conversation.accumulated_cost == 15.0
|
||||
finally:
|
||||
# Restore the original session_maker
|
||||
callback_utils.session_maker = original_session_maker
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_batch_operations_events_success(
|
||||
@@ -644,20 +705,27 @@ class TestProcessBatchOperationsBackground:
|
||||
),
|
||||
]
|
||||
|
||||
with patch(
|
||||
'server.routes.event_webhook.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
), patch(
|
||||
'server.routes.event_webhook._get_session_api_key'
|
||||
) as mock_get_api_key, patch(
|
||||
'server.utils.conversation_callback_utils.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
):
|
||||
# First call succeeds, second fails
|
||||
mock_get_api_key.side_effect = ['correct-api-key', 'wrong-api-key']
|
||||
# Import the module and patch the session_maker at the module level
|
||||
import server.utils.conversation_callback_utils as callback_utils
|
||||
|
||||
# Should not raise exceptions, just log errors
|
||||
await _process_batch_operations_background(batch_ops, 'correct-api-key')
|
||||
original_session_maker = callback_utils.session_maker
|
||||
|
||||
try:
|
||||
with patch(
|
||||
'server.routes.event_webhook.session_maker',
|
||||
session_maker_with_minimal_fixtures,
|
||||
), patch(
|
||||
'server.routes.event_webhook._get_session_api_key'
|
||||
) as mock_get_api_key:
|
||||
# First call succeeds, second fails
|
||||
mock_get_api_key.side_effect = ['correct-api-key', 'wrong-api-key']
|
||||
callback_utils.session_maker = session_maker_with_minimal_fixtures
|
||||
|
||||
# Should not raise exceptions, just log errors
|
||||
await _process_batch_operations_background(batch_ops, 'correct-api-key')
|
||||
finally:
|
||||
# Restore the original session_maker
|
||||
callback_utils.session_maker = original_session_maker
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_batch_operations_invalid_method_skipped(
|
||||
|
||||
@@ -0,0 +1,371 @@
|
||||
"""Tests for SaasSQLAppConversationInfoService.
|
||||
|
||||
This module tests the SAAS implementation of SQLAppConversationInfoService,
|
||||
focusing on user isolation, SAAS metadata handling, and multi-tenant functionality.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
# Import the SAAS service
|
||||
from enterprise.storage.saas_app_conversation_info_injector import (
|
||||
SaasSQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(async_session) -> SaasSQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance for testing."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service_with_user(async_session) -> SaasSQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance with a user_id for testing."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(user_id='a1111111-1111-1111-1111-111111111111'),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation_info() -> AppConversationInfo:
|
||||
"""Create a sample AppConversationInfo for testing."""
|
||||
return AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
||||
sandbox_id='sandbox_123',
|
||||
selected_repository='https://github.com/test/repo',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Test Conversation',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[123, 456],
|
||||
llm_model='gpt-4',
|
||||
metrics=None,
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multiple_conversation_infos() -> list[AppConversationInfo]:
|
||||
"""Create multiple AppConversationInfo instances for testing."""
|
||||
base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
return [
|
||||
AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=None,
|
||||
sandbox_id=f'sandbox_{i}',
|
||||
selected_repository=f'https://github.com/test/repo{i}',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title=f'Test Conversation {i}',
|
||||
trigger=ConversationTrigger.GUI,
|
||||
pr_number=[i * 100],
|
||||
llm_model='gpt-4',
|
||||
metrics=None,
|
||||
created_at=base_time.replace(hour=12 + i),
|
||||
updated_at=base_time.replace(hour=12 + i, minute=30),
|
||||
)
|
||||
for i in range(1, 6) # Create 5 conversations
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user1_context():
|
||||
"""Create user context for user1."""
|
||||
return SpecifyUserContext(user_id='a1111111-1111-1111-1111-111111111111')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user2_context():
|
||||
"""Create user context for user2."""
|
||||
return SpecifyUserContext(user_id='b2222222-2222-2222-2222-222222222222')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def saas_service_user1(mock_db_session, user1_context):
|
||||
"""Create a SaasSQLAppConversationInfoService instance for user1."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=mock_db_session, user_context=user1_context
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def saas_service_user2(mock_db_session, user2_context):
|
||||
"""Create a SaasSQLAppConversationInfoService instance for user2."""
|
||||
return SaasSQLAppConversationInfoService(
|
||||
db_session=mock_db_session, user_context=user2_context
|
||||
)
|
||||
|
||||
|
||||
class TestSaasSQLAppConversationInfoService:
|
||||
"""Test suite for SaasSQLAppConversationInfoService."""
|
||||
|
||||
def test_service_initialization(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
user1_context: SpecifyUserContext,
|
||||
):
|
||||
"""Test that the SAAS service is properly initialized."""
|
||||
assert saas_service_user1.user_context == user1_context
|
||||
assert saas_service_user1.db_session is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_context_isolation(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
saas_service_user2: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that different service instances have different user contexts."""
|
||||
user1_id = await saas_service_user1.user_context.get_user_id()
|
||||
user2_id = await saas_service_user2.user_context.get_user_id()
|
||||
|
||||
assert user1_id == 'a1111111-1111-1111-1111-111111111111'
|
||||
assert user2_id == 'b2222222-2222-2222-2222-222222222222'
|
||||
assert user1_id != user2_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secure_select_includes_user_filtering(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that _secure_select method includes user filtering."""
|
||||
# This test verifies that the _secure_select method exists and can be called
|
||||
# The actual SQL generation is tested implicitly through integration
|
||||
query = await saas_service_user1._secure_select()
|
||||
assert query is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_info_with_user_id_functionality(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
):
|
||||
"""Test that _to_info_with_user_id properly sets user_id from SAAS metadata."""
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
# Create mock metadata objects
|
||||
stored_metadata = MagicMock(spec=StoredConversationMetadata)
|
||||
stored_metadata.conversation_id = '12345678-1234-5678-1234-567812345678'
|
||||
stored_metadata.parent_conversation_id = None
|
||||
stored_metadata.title = 'Test Conversation'
|
||||
stored_metadata.sandbox_id = 'test-sandbox'
|
||||
stored_metadata.selected_repository = None
|
||||
stored_metadata.selected_branch = None
|
||||
stored_metadata.git_provider = None
|
||||
stored_metadata.trigger = None
|
||||
stored_metadata.pr_number = []
|
||||
stored_metadata.llm_model = None
|
||||
from datetime import datetime, timezone
|
||||
|
||||
stored_metadata.created_at = datetime.now(timezone.utc)
|
||||
stored_metadata.last_updated_at = datetime.now(timezone.utc)
|
||||
stored_metadata.accumulated_cost = 0.0
|
||||
stored_metadata.prompt_tokens = 0
|
||||
stored_metadata.completion_tokens = 0
|
||||
stored_metadata.total_tokens = 0
|
||||
stored_metadata.max_budget_per_task = None
|
||||
stored_metadata.cache_read_tokens = 0
|
||||
stored_metadata.cache_write_tokens = 0
|
||||
stored_metadata.reasoning_tokens = 0
|
||||
stored_metadata.context_window = 0
|
||||
stored_metadata.per_turn_token = 0
|
||||
|
||||
saas_metadata = MagicMock(spec=StoredConversationMetadataSaas)
|
||||
saas_metadata.user_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
saas_metadata.org_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
|
||||
# Test the _to_info_with_user_id method
|
||||
result = saas_service_user1._to_info_with_user_id(
|
||||
stored_metadata, saas_metadata
|
||||
)
|
||||
|
||||
# Verify that the user_id from SAAS metadata is used
|
||||
assert result.created_by_user_id == 'a1111111-1111-1111-1111-111111111111'
|
||||
assert result.title == 'Test Conversation'
|
||||
assert result.sandbox_id == 'test-sandbox'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_isolation(
|
||||
self,
|
||||
async_session: AsyncSession,
|
||||
multiple_conversation_infos: list[AppConversationInfo],
|
||||
):
|
||||
"""Test that user isolation works correctly."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from storage.user import User
|
||||
|
||||
# Mock the database session execute method to return mock users
|
||||
# This mock intercepts User queries and returns a mock user object
|
||||
# with user_id and org_id the same as the user_id_uuid from the query
|
||||
original_execute = async_session.execute
|
||||
|
||||
async def mock_execute(query):
|
||||
query_str = str(query)
|
||||
|
||||
# Check if this is a User query
|
||||
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
|
||||
# Extract the UUID from the query parameters
|
||||
# The query will have bound parameters, we need to get the UUID value
|
||||
if hasattr(query, 'compile'):
|
||||
try:
|
||||
compiled = query.compile(compile_kwargs={'literal_binds': True})
|
||||
query_with_params = str(compiled)
|
||||
|
||||
# Extract UUID from the query string
|
||||
import re
|
||||
|
||||
# Try both formats: with dashes and without dashes
|
||||
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
|
||||
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
|
||||
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_with_dashes, query_with_params
|
||||
)
|
||||
if not uuid_match:
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_without_dashes, query_with_params
|
||||
)
|
||||
|
||||
if uuid_match:
|
||||
user_id_str = uuid_match.group(0)
|
||||
# If the UUID doesn't have dashes, add them
|
||||
if len(user_id_str) == 32 and '-' not in user_id_str:
|
||||
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
|
||||
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
|
||||
# Create a mock user with user_id and org_id the same as user_id_uuid
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.id = user_id_uuid
|
||||
mock_user.current_org_id = user_id_uuid
|
||||
|
||||
# Create a mock result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_user
|
||||
return mock_result
|
||||
except Exception:
|
||||
# If there's any error in parsing, fall back to original execute
|
||||
pass
|
||||
|
||||
# For all other queries, use the original execute method
|
||||
return await original_execute(query)
|
||||
|
||||
# Apply the mock
|
||||
async_session.execute = mock_execute
|
||||
|
||||
# Create services for different users
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='a1111111-1111-1111-1111-111111111111'
|
||||
),
|
||||
)
|
||||
user2_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='b2222222-2222-2222-2222-222222222222'
|
||||
),
|
||||
)
|
||||
|
||||
# Create conversations for different users
|
||||
user1_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
||||
sandbox_id='sandbox_user1',
|
||||
title='User 1 Conversation',
|
||||
)
|
||||
|
||||
user2_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='b2222222-2222-2222-2222-222222222222',
|
||||
sandbox_id='sandbox_user2',
|
||||
title='User 2 Conversation',
|
||||
)
|
||||
|
||||
# Save conversations
|
||||
await user1_service.save_app_conversation_info(user1_info)
|
||||
await user2_service.save_app_conversation_info(user2_info)
|
||||
|
||||
# User 1 should only see their conversation
|
||||
user1_page = await user1_service.search_app_conversation_info()
|
||||
assert len(user1_page.items) == 1
|
||||
assert (
|
||||
user1_page.items[0].created_by_user_id
|
||||
== 'a1111111-1111-1111-1111-111111111111'
|
||||
)
|
||||
|
||||
# User 2 should only see their conversation
|
||||
user2_page = await user2_service.search_app_conversation_info()
|
||||
assert len(user2_page.items) == 1
|
||||
assert (
|
||||
user2_page.items[0].created_by_user_id
|
||||
== 'b2222222-2222-2222-2222-222222222222'
|
||||
)
|
||||
|
||||
# User 1 should not be able to get user 2's conversation
|
||||
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
|
||||
assert user2_from_user1 is None
|
||||
|
||||
# User 2 should not be able to get user 1's conversation
|
||||
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
|
||||
assert user1_from_user2 is None
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
@@ -19,6 +19,14 @@ def mock_session_maker(mock_session):
|
||||
return session_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Mock user with org_id."""
|
||||
user = MagicMock()
|
||||
user.current_org_id = 'test-org-123'
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_store(mock_session_maker):
|
||||
return ApiKeyStore(mock_session_maker)
|
||||
@@ -31,11 +39,13 @@ def test_generate_api_key(api_key_store):
|
||||
assert len(key) == 32
|
||||
|
||||
|
||||
def test_create_api_key(api_key_store, mock_session):
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
"""Test creating an API key."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
name = 'Test Key'
|
||||
mock_get_user.return_value = mock_user
|
||||
api_key_store.generate_api_key = MagicMock(return_value='test-api-key')
|
||||
|
||||
# Execute
|
||||
@@ -43,10 +53,15 @@ def test_create_api_key(api_key_store, mock_session):
|
||||
|
||||
# Verify
|
||||
assert result == 'test-api-key'
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
api_key_store.generate_api_key.assert_called_once()
|
||||
|
||||
# Verify the ApiKey was created with the correct org_id
|
||||
added_api_key = mock_session.add.call_args[0][0]
|
||||
assert added_api_key.org_id == mock_user.current_org_id
|
||||
|
||||
|
||||
def test_validate_api_key_valid(api_key_store, mock_session):
|
||||
"""Test validating a valid API key."""
|
||||
@@ -158,10 +173,12 @@ def test_delete_api_key_by_id(api_key_store, mock_session):
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_list_api_keys(api_key_store, mock_session):
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
"""Test listing API keys for a user."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
now = datetime.now(UTC)
|
||||
mock_key1 = MagicMock()
|
||||
mock_key1.id = 1
|
||||
@@ -177,15 +194,17 @@ def test_list_api_keys(api_key_store, mock_session):
|
||||
mock_key2.last_used_at = None
|
||||
mock_key2.expires_at = None
|
||||
|
||||
mock_session.query.return_value.filter.return_value.all.return_value = [
|
||||
mock_key1,
|
||||
mock_key2,
|
||||
]
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_key1, mock_key2]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.list_api_keys(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert len(result) == 2
|
||||
assert result[0]['id'] == 1
|
||||
assert result[0]['name'] == 'Key 1'
|
||||
@@ -198,3 +217,59 @@ def test_list_api_keys(api_key_store, mock_session):
|
||||
assert result[1]['created_at'] == now
|
||||
assert result[1]['last_used_at'] is None
|
||||
assert result[1]['expires_at'] is None
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_retrieve_mcp_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
"""Test retrieving MCP API key for a user."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
mock_mcp_key = MagicMock()
|
||||
mock_mcp_key.name = 'MCP_API_KEY'
|
||||
mock_mcp_key.key = 'mcp-test-key'
|
||||
|
||||
mock_other_key = MagicMock()
|
||||
mock_other_key.name = 'Other Key'
|
||||
mock_other_key.key = 'other-test-key'
|
||||
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_other_key, mock_mcp_key]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert result == 'mcp-test-key'
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_retrieve_mcp_api_key_not_found(
|
||||
mock_get_user, api_key_store, mock_session, mock_user
|
||||
):
|
||||
"""Test retrieving MCP API key when none exists."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
mock_other_key = MagicMock()
|
||||
mock_other_key.name = 'Other Key'
|
||||
mock_other_key.key = 'other-test-key'
|
||||
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_other_key]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert result is None
|
||||
|
||||
@@ -127,6 +127,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -140,6 +141,15 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
|
||||
# Mock the user creation
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = None
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = False
|
||||
|
||||
@@ -161,20 +171,19 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.set_response_cookie') as mock_set_cookie,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.posthog') as mock_posthog,
|
||||
):
|
||||
# Mock the session and query results
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
# Mock user settings with accepted_tos
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
# Setup UserStore mocks
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -226,20 +235,20 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
),
|
||||
patch('server.routes.auth.KEYCLOAK_REALM_NAME', 'test-realm'),
|
||||
patch('server.routes.auth.KEYCLOAK_CLIENT_ID', 'test-client'),
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.posthog') as mock_posthog,
|
||||
):
|
||||
# Mock the session and query results
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
# Setup UserStore mocks
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
|
||||
# Mock user settings with accepted_tos
|
||||
mock_user_settings = MagicMock()
|
||||
mock_user_settings.accepted_tos = '2025-01-01'
|
||||
mock_query.first.return_value = mock_user_settings
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import stripe
|
||||
from fastapi import HTTPException, Request, status
|
||||
from httpx import HTTPStatusError, Response
|
||||
from integrations.stripe_service import has_payment_method
|
||||
from httpx import Response
|
||||
from server.routes import billing
|
||||
from server.routes.billing import (
|
||||
CreateBillingSessionResponse,
|
||||
CreateCheckoutSessionRequest,
|
||||
GetCreditsResponse,
|
||||
cancel_callback,
|
||||
cancel_subscription,
|
||||
create_checkout_session,
|
||||
create_subscription_checkout_session,
|
||||
create_customer_setup_session,
|
||||
get_credits,
|
||||
has_payment_method,
|
||||
success_callback,
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session_type import BillingSessionType
|
||||
from storage.stripe_customer import Base as StripeCustomerBase
|
||||
|
||||
|
||||
@@ -78,29 +78,31 @@ def mock_subscription_request():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_lite_llm_error():
|
||||
mock_request = Request(scope={'type': 'http', 'state': {'user_id': 'mock_user'}})
|
||||
|
||||
mock_response = Response(
|
||||
status_code=500, json={'error': 'Internal Server Error'}, request=MagicMock()
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get.return_value = mock_response
|
||||
|
||||
with patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'):
|
||||
with patch('httpx.AsyncClient', return_value=mock_client):
|
||||
with pytest.raises(HTTPStatusError) as exc_info:
|
||||
await get_credits(mock_request)
|
||||
assert (
|
||||
exc_info.value.response.status_code
|
||||
== status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
with (
|
||||
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await get_credits('mock_user')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_success():
|
||||
mock_response = Response(
|
||||
status_code=200,
|
||||
json={'user_info': {'max_budget': 100.00, 'spend': 25.50}},
|
||||
json={
|
||||
'user_info': {
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
}
|
||||
},
|
||||
request=MagicMock(),
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
@@ -109,24 +111,22 @@ async def test_get_credits_success():
|
||||
with (
|
||||
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
|
||||
patch('httpx.AsyncClient', return_value=mock_client),
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
},
|
||||
),
|
||||
):
|
||||
with patch('server.routes.billing.session_maker') as mock_session_maker:
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.first.return_value = MagicMock(
|
||||
billing_margin=4
|
||||
)
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
result = await get_credits('mock_user')
|
||||
|
||||
result = await get_credits('mock_user')
|
||||
|
||||
assert isinstance(result, GetCreditsResponse)
|
||||
assert result.credits == Decimal(
|
||||
'74.50'
|
||||
) # 100.00 - 25.50 = 74.50 (no billing margin applied)
|
||||
mock_client.__aenter__.return_value.get.assert_called_once_with(
|
||||
'https://llm-proxy.app.all-hands.dev/user/info?user_id=mock_user',
|
||||
headers={'x-goog-api-key': None},
|
||||
)
|
||||
assert isinstance(result, GetCreditsResponse)
|
||||
assert result.credits == Decimal('74.50') # 100.00 - 25.50 = 74.50
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -139,6 +139,9 @@ async def test_create_checkout_session_stripe_error(
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
)
|
||||
mock_customer_create = AsyncMock(return_value=mock_customer)
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = uuid.uuid4()
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
with (
|
||||
pytest.raises(Exception, match='Stripe API Error'),
|
||||
patch('stripe.Customer.create_async', mock_customer_create),
|
||||
@@ -150,6 +153,10 @@ async def test_create_checkout_session_stripe_error(
|
||||
AsyncMock(side_effect=Exception('Stripe API Error')),
|
||||
),
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
@@ -175,6 +182,10 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
)
|
||||
mock_customer_create = AsyncMock(return_value=mock_customer)
|
||||
mock_org = MagicMock()
|
||||
mock_org_id = uuid.uuid4()
|
||||
mock_org.id = mock_org_id
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
with (
|
||||
patch('stripe.Customer.create_async', mock_customer_create),
|
||||
patch(
|
||||
@@ -183,6 +194,10 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
@@ -254,7 +269,6 @@ async def test_success_callback_stripe_incomplete():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
@@ -282,44 +296,33 @@ async def test_success_callback_success():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
mock_lite_llm_response = Response(
|
||||
status_code=200,
|
||||
json={'user_info': {'max_budget': 100.00, 'spend': 25.50}},
|
||||
request=MagicMock(),
|
||||
)
|
||||
mock_lite_llm_update_response = Response(
|
||||
status_code=200, json={}, request=MagicMock()
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch('httpx.AsyncClient') as mock_client,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
},
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
|
||||
) as mock_update_budget,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_user_settings = MagicMock(billing_margin=None)
|
||||
mock_db_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_user_settings
|
||||
)
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete',
|
||||
amount_subtotal=2500,
|
||||
status='complete', amount_subtotal=2500, customer='mock_customer_id'
|
||||
) # $25.00 in cents
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__.return_value.get.return_value = (
|
||||
mock_lite_llm_response
|
||||
)
|
||||
mock_client_instance.__aenter__.return_value.post.return_value = (
|
||||
mock_lite_llm_update_response
|
||||
)
|
||||
mock_client.return_value = mock_client_instance
|
||||
|
||||
response = await success_callback('test_session_id', mock_request)
|
||||
|
||||
assert response.status_code == 302
|
||||
@@ -329,18 +332,14 @@ async def test_success_callback_success():
|
||||
)
|
||||
|
||||
# Verify LiteLLM API calls
|
||||
mock_client_instance.__aenter__.return_value.get.assert_called_once()
|
||||
mock_client_instance.__aenter__.return_value.post.assert_called_once_with(
|
||||
'https://llm-proxy.app.all-hands.dev/user/update',
|
||||
headers={'x-goog-api-key': None},
|
||||
json={
|
||||
'user_id': 'mock_user',
|
||||
'max_budget': 125,
|
||||
}, # 100 + (25.00 from Stripe)
|
||||
mock_update_budget.assert_called_once_with(
|
||||
'mock_org_id',
|
||||
125.0, # 100 + (25.00 from Stripe)
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
assert mock_billing_session.status == 'completed'
|
||||
assert mock_billing_session.price == 25.0
|
||||
mock_db_session.merge.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@@ -354,27 +353,27 @@ async def test_success_callback_lite_llm_error():
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch('httpx.AsyncClient') as mock_client,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id',
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete', amount_total=2500
|
||||
status='complete', amount_subtotal=2500
|
||||
)
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.__aenter__.return_value.get.side_effect = Exception(
|
||||
'LiteLLM API Error'
|
||||
)
|
||||
mock_client.return_value = mock_client_instance
|
||||
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await success_callback('test_session_id', mock_request)
|
||||
|
||||
@@ -398,7 +397,8 @@ async def test_cancel_callback_session_not_found():
|
||||
response = await cancel_callback('test_session_id', mock_request)
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location'] == 'http://test.com/settings?checkout=cancel'
|
||||
response.headers['location']
|
||||
== 'http://test.com/settings/billing?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify no database updates occurred
|
||||
@@ -424,7 +424,8 @@ async def test_cancel_callback_success():
|
||||
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location'] == 'http://test.com/settings?checkout=cancel'
|
||||
response.headers['location']
|
||||
== 'http://test.com/settings/billing?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
@@ -436,314 +437,67 @@ async def test_cancel_callback_success():
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_payment_method_with_payment_method():
|
||||
"""Test has_payment_method returns True when user has a payment method."""
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[MagicMock()])),
|
||||
) as mock_list_payment_methods,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
MagicMock(stripe_customer_id='cus_test123')
|
||||
)
|
||||
|
||||
mock_has_payment_method = AsyncMock(return_value=True)
|
||||
with patch(
|
||||
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
|
||||
mock_has_payment_method,
|
||||
):
|
||||
result = await has_payment_method('mock_user')
|
||||
assert result is True
|
||||
mock_list_payment_methods.assert_called_once_with('cus_test123')
|
||||
mock_has_payment_method.assert_called_once_with('mock_user')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_payment_method_without_payment_method():
|
||||
"""Test has_payment_method returns False when user has no payment method."""
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[])),
|
||||
) as mock_list_payment_methods,
|
||||
mock_has_payment_method = AsyncMock(return_value=False)
|
||||
with patch(
|
||||
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
|
||||
mock_has_payment_method,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
MagicMock(stripe_customer_id='cus_test123')
|
||||
)
|
||||
|
||||
mock_has_payment_method.return_value = False
|
||||
result = await has_payment_method('mock_user')
|
||||
assert result is False
|
||||
mock_list_payment_methods.assert_called_once_with('cus_test123')
|
||||
mock_has_payment_method.assert_called_once_with('mock_user')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_success():
|
||||
"""Test successful subscription cancellation."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
async def test_create_customer_setup_session_success():
|
||||
"""Test successful creation of customer setup session."""
|
||||
mock_request = Request(
|
||||
scope={
|
||||
'type': 'http',
|
||||
'path': '/api/billing/create-customer-setup-session',
|
||||
'server': ('test.com', 80),
|
||||
'headers': [],
|
||||
}
|
||||
)
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
# Mock Stripe subscription response
|
||||
mock_stripe_subscription = MagicMock()
|
||||
mock_stripe_subscription.cancel_at_period_end = True
|
||||
mock_customer_info = {'customer_id': 'mock-customer-id', 'org_id': 'mock-org-id'}
|
||||
mock_session = MagicMock()
|
||||
mock_session.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_create = AsyncMock(return_value=mock_session)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Subscription.modify_async',
|
||||
AsyncMock(return_value=mock_stripe_subscription),
|
||||
) as mock_stripe_modify,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function
|
||||
result = await cancel_subscription('test_user')
|
||||
|
||||
# Verify Stripe API was called
|
||||
mock_stripe_modify.assert_called_once_with(
|
||||
'sub_test123', cancel_at_period_end=True
|
||||
)
|
||||
|
||||
# Verify database was updated
|
||||
assert mock_subscription_access.cancelled_at is not None
|
||||
mock_session.merge.assert_called_once_with(mock_subscription_access)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Verify response
|
||||
assert result.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_no_active_subscription():
|
||||
"""Test cancellation when no active subscription exists."""
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Setup mock session with no subscription found
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert 'No active subscription found' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_missing_stripe_id():
|
||||
"""Test cancellation when subscription has no Stripe ID."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock subscription without Stripe ID
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id=None, # Missing Stripe ID
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert 'missing Stripe subscription ID' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_subscription_stripe_error():
|
||||
"""Test cancellation when Stripe API fails."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'stripe.Subscription.modify_async',
|
||||
AsyncMock(side_effect=stripe.StripeError('API Error')),
|
||||
'integrations.stripe_service.find_or_create_customer_by_user_id',
|
||||
AsyncMock(return_value=mock_customer_info),
|
||||
),
|
||||
):
|
||||
# Setup mock session
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await cancel_subscription('test_user')
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to cancel subscription' in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_duplicate_prevention(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test that creating a subscription when user already has active subscription raises error."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
# Mock active subscription
|
||||
mock_subscription_access = SubscriptionAccess(
|
||||
id=1,
|
||||
status='ACTIVE',
|
||||
user_id='test_user',
|
||||
start_at=datetime.now(UTC),
|
||||
end_at=datetime.now(UTC),
|
||||
amount_paid=2000,
|
||||
stripe_invoice_payment_id='pi_test',
|
||||
stripe_subscription_id='sub_test123',
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session to return existing active subscription
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
|
||||
result = await create_customer_setup_session(mock_request, 'mock_user')
|
||||
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert (
|
||||
'user already has an active subscription'
|
||||
in str(exc_info.value.detail).lower()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_allows_after_cancellation(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test that creating a subscription is allowed when previous subscription was cancelled."""
|
||||
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_session_obj.id = 'test_session_id'
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'integrations.stripe_service.find_or_create_customer',
|
||||
AsyncMock(return_value='cus_test123'),
|
||||
),
|
||||
patch(
|
||||
'stripe.checkout.Session.create_async',
|
||||
AsyncMock(return_value=mock_session_obj),
|
||||
),
|
||||
patch(
|
||||
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
|
||||
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session - the query should return None because cancelled subscriptions are filtered out
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Should succeed
|
||||
result = await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
assert isinstance(result, billing.CreateBillingSessionResponse)
|
||||
assert result.redirect_url == 'https://checkout.stripe.com/test-session'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_success_no_existing(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test successful subscription creation when no existing subscription."""
|
||||
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_session_obj.id = 'test_session_id'
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch(
|
||||
'integrations.stripe_service.find_or_create_customer',
|
||||
AsyncMock(return_value='cus_test123'),
|
||||
),
|
||||
patch(
|
||||
'stripe.checkout.Session.create_async',
|
||||
AsyncMock(return_value=mock_session_obj),
|
||||
),
|
||||
patch(
|
||||
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
|
||||
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session to return no existing subscription
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
# Should succeed
|
||||
result = await create_subscription_checkout_session(
|
||||
mock_subscription_request, user_id='test_user'
|
||||
# Verify Stripe session creation parameters
|
||||
mock_create.assert_called_once_with(
|
||||
customer='mock-customer-id',
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url='http://test.com/?free_credits=success',
|
||||
cancel_url='http://test.com/',
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
assert result.redirect_url == 'https://checkout.stripe.com/test-session'
|
||||
|
||||
@@ -3,14 +3,29 @@ Tests for ConversationCallbackProcessor and ConversationCallback models.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from storage.conversation_callback import (
|
||||
CallbackStatus,
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
|
||||
# Import the actual StoredConversationMetadata from OpenHands core
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
|
||||
# Mock the lazy import to return the actual class
|
||||
with patch(
|
||||
'storage.stored_conversation_metadata.StoredConversationMetadata',
|
||||
StoredConversationMetadata,
|
||||
):
|
||||
from storage.conversation_callback import (
|
||||
CallbackStatus,
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
|
||||
@@ -80,15 +95,22 @@ class TestConversationCallback:
|
||||
"""Create a test conversation metadata record."""
|
||||
with session_maker() as session:
|
||||
metadata = StoredConversationMetadata(
|
||||
conversation_id='test_conversation_123', user_id='test_user_456'
|
||||
conversation_id='test_conversation_123'
|
||||
)
|
||||
metadata_saas = StoredConversationMetadataSaas(
|
||||
conversation_id='test_conversation_123',
|
||||
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
)
|
||||
session.add(metadata)
|
||||
session.add(metadata_saas)
|
||||
session.commit()
|
||||
session.refresh(metadata)
|
||||
yield metadata
|
||||
|
||||
# Cleanup
|
||||
session.delete(metadata)
|
||||
session.delete(metadata_saas)
|
||||
session.commit()
|
||||
|
||||
def test_callback_creation(self, conversation_metadata, session_maker):
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from unittest import TestCase, mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from integrations.github.github_view import GithubFactory, get_oh_labels
|
||||
import pytest
|
||||
from integrations.github.github_view import GithubFactory, GithubIssue, get_oh_labels
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.types import UserData
|
||||
|
||||
|
||||
class TestGithubLabels(TestCase):
|
||||
@@ -75,3 +78,132 @@ class TestGithubCommentCaseInsensitivity(TestCase):
|
||||
self.assertTrue(GithubFactory.is_issue_comment(message_lower))
|
||||
self.assertTrue(GithubFactory.is_issue_comment(message_upper))
|
||||
self.assertTrue(GithubFactory.is_issue_comment(message_mixed))
|
||||
|
||||
|
||||
class TestGithubV1ConversationRouting(TestCase):
|
||||
"""Test V1 conversation routing logic in GitHub integration."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
# Create a proper UserData instance instead of MagicMock
|
||||
user_data = UserData(
|
||||
user_id=123, username='testuser', keycloak_user_id='test-keycloak-id'
|
||||
)
|
||||
|
||||
# Create a mock raw_payload
|
||||
raw_payload = Message(
|
||||
source=SourceType.GITHUB,
|
||||
message={
|
||||
'payload': {
|
||||
'action': 'opened',
|
||||
'issue': {'number': 123},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
self.github_issue = GithubIssue(
|
||||
user_info=user_data,
|
||||
full_repo_name='test/repo',
|
||||
issue_number=123,
|
||||
installation_id=456,
|
||||
conversation_id='test-conversation-id',
|
||||
should_extract=True,
|
||||
send_summary_instruction=False,
|
||||
is_public_repo=True,
|
||||
raw_payload=raw_payload,
|
||||
uuid='test-uuid',
|
||||
title='Test Issue',
|
||||
description='Test issue description',
|
||||
previous_comments=[],
|
||||
v1=False,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.github.github_view.get_user_v1_enabled_setting')
|
||||
@patch.object(GithubIssue, '_create_v0_conversation')
|
||||
@patch.object(GithubIssue, '_create_v1_conversation')
|
||||
async def test_create_new_conversation_routes_to_v0_when_disabled(
|
||||
self, mock_create_v1, mock_create_v0, mock_get_v1_setting
|
||||
):
|
||||
"""Test that conversation creation routes to V0 when v1_enabled is False."""
|
||||
# Mock v1_enabled as False
|
||||
mock_get_v1_setting.return_value = False
|
||||
mock_create_v0.return_value = None
|
||||
mock_create_v1.return_value = None
|
||||
|
||||
# Mock parameters
|
||||
jinja_env = MagicMock()
|
||||
git_provider_tokens = MagicMock()
|
||||
conversation_metadata = MagicMock()
|
||||
|
||||
# Call the method
|
||||
await self.github_issue.create_new_conversation(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
|
||||
# Verify V0 was called and V1 was not
|
||||
mock_create_v0.assert_called_once_with(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
mock_create_v1.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.github.github_view.get_user_v1_enabled_setting')
|
||||
@patch.object(GithubIssue, '_create_v0_conversation')
|
||||
@patch.object(GithubIssue, '_create_v1_conversation')
|
||||
async def test_create_new_conversation_routes_to_v1_when_enabled(
|
||||
self, mock_create_v1, mock_create_v0, mock_get_v1_setting
|
||||
):
|
||||
"""Test that conversation creation routes to V1 when v1_enabled is True."""
|
||||
# Mock v1_enabled as True
|
||||
mock_get_v1_setting.return_value = True
|
||||
mock_create_v0.return_value = None
|
||||
mock_create_v1.return_value = None
|
||||
|
||||
# Mock parameters
|
||||
jinja_env = MagicMock()
|
||||
git_provider_tokens = MagicMock()
|
||||
conversation_metadata = MagicMock()
|
||||
|
||||
# Call the method
|
||||
await self.github_issue.create_new_conversation(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
|
||||
# Verify V1 was called and V0 was not
|
||||
mock_create_v1.assert_called_once_with(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
mock_create_v0.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.github.github_view.get_user_v1_enabled_setting')
|
||||
@patch.object(GithubIssue, '_create_v0_conversation')
|
||||
@patch.object(GithubIssue, '_create_v1_conversation')
|
||||
async def test_create_new_conversation_fallback_on_v1_setting_error(
|
||||
self, mock_create_v1, mock_create_v0, mock_get_v1_setting
|
||||
):
|
||||
"""Test that conversation creation falls back to V0 when _create_v1_conversation fails."""
|
||||
# Mock v1_enabled as True so V1 is attempted
|
||||
mock_get_v1_setting.return_value = True
|
||||
# Mock _create_v1_conversation to raise an exception
|
||||
mock_create_v1.side_effect = Exception('V1 conversation creation failed')
|
||||
mock_create_v0.return_value = None
|
||||
|
||||
# Mock parameters
|
||||
jinja_env = MagicMock()
|
||||
git_provider_tokens = MagicMock()
|
||||
conversation_metadata = MagicMock()
|
||||
|
||||
# Call the method
|
||||
await self.github_issue.create_new_conversation(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
|
||||
# Verify V1 was attempted first, then V0 was called as fallback
|
||||
mock_create_v1.assert_called_once_with(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
mock_create_v0.assert_called_once_with(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
|
||||
650
enterprise/tests/unit/test_lite_llm_manager.py
Normal file
650
enterprise/tests/unit/test_lite_llm_manager.py
Normal file
@@ -0,0 +1,650 @@
|
||||
"""
|
||||
Unit tests for LiteLlmManager class.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
from server.constants import (
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.server.settings import Settings
|
||||
|
||||
|
||||
class TestLiteLlmManager:
|
||||
"""Test cases for LiteLlmManager class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings(self):
|
||||
"""Create a mock Settings object."""
|
||||
settings = Settings()
|
||||
settings.agent = 'TestAgent'
|
||||
settings.llm_model = 'test-model'
|
||||
settings.llm_api_key = SecretStr('test-key')
|
||||
settings.llm_base_url = 'http://test.com'
|
||||
return settings
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_settings(self):
|
||||
"""Create a mock UserSettings object."""
|
||||
user_settings = UserSettings()
|
||||
user_settings.agent = 'TestAgent'
|
||||
user_settings.llm_model = 'test-model'
|
||||
user_settings.llm_api_key = SecretStr('test-key')
|
||||
user_settings.llm_base_url = 'http://test.com'
|
||||
return user_settings
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_client(self):
|
||||
"""Create a mock HTTP client."""
|
||||
client = AsyncMock(spec=httpx.AsyncClient)
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response(self):
|
||||
"""Create a mock HTTP response."""
|
||||
response = MagicMock()
|
||||
response.is_success = True
|
||||
response.status_code = 200
|
||||
response.text = 'Success'
|
||||
response.json.return_value = {'key': 'test-api-key'}
|
||||
response.raise_for_status = MagicMock()
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def mock_team_response(self):
|
||||
"""Create a mock team response."""
|
||||
response = MagicMock()
|
||||
response.is_success = True
|
||||
response.status_code = 200
|
||||
response.json.return_value = {
|
||||
'team_memberships': [
|
||||
{
|
||||
'user_id': 'test-user-id',
|
||||
'team_id': 'test-org-id',
|
||||
'max_budget': 100.0,
|
||||
}
|
||||
]
|
||||
}
|
||||
response.raise_for_status = MagicMock()
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_response(self):
|
||||
"""Create a mock user response."""
|
||||
response = MagicMock()
|
||||
response.is_success = True
|
||||
response.status_code = 200
|
||||
response.json.return_value = {
|
||||
'user_info': {
|
||||
'max_budget': 50.0,
|
||||
'spend': 10.0,
|
||||
}
|
||||
}
|
||||
response.raise_for_status = MagicMock()
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def mock_key_info_response(self):
|
||||
"""Create a mock key info response."""
|
||||
response = MagicMock()
|
||||
response.is_success = True
|
||||
response.status_code = 200
|
||||
response.json.return_value = {
|
||||
'info': {
|
||||
'max_budget': 100.0,
|
||||
'spend': 25.0,
|
||||
}
|
||||
}
|
||||
response.raise_for_status = MagicMock()
|
||||
return response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_missing_config(self, mock_settings):
|
||||
"""Test create_entries when LiteLLM config is missing."""
|
||||
with patch.dict(os.environ, {'LITE_LLM_API_KEY': '', 'LITE_LLM_API_URL': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_local_deployment(self, mock_settings):
|
||||
"""Test create_entries in local deployment mode."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': '1'}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.llm_model == get_default_litellm_model()
|
||||
assert result.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_cloud_deployment(self, mock_settings, mock_response):
|
||||
"""Test create_entries in cloud deployment mode."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.llm_model == get_default_litellm_model()
|
||||
assert (
|
||||
result.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
)
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
# Verify API calls were made
|
||||
assert (
|
||||
mock_client.post.call_count == 4
|
||||
) # create_team, create_user, add_user_to_team, generate_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_missing_config(self, mock_user_settings):
|
||||
"""Test migrate_entries when LiteLLM config is missing."""
|
||||
with patch.dict(os.environ, {'LITE_LLM_API_KEY': '', 'LITE_LLM_API_URL': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_local_deployment(self, mock_user_settings):
|
||||
"""Test migrate_entries in local deployment mode."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': '1'}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.llm_model == get_default_litellm_model()
|
||||
assert result.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_no_user_found(self, mock_user_settings):
|
||||
"""Test migrate_entries when user is not found."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
|
||||
# Mock the _get_user method directly to return None
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_user', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
mock_get_user.return_value = None
|
||||
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_already_migrated(
|
||||
self, mock_user_settings, mock_user_response
|
||||
):
|
||||
"""Test migrate_entries when user is already migrated (no max_budget)."""
|
||||
mock_user_response.json.return_value = {
|
||||
'user_info': {
|
||||
'max_budget': None, # Already migrated
|
||||
'spend': 10.0,
|
||||
}
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.get.return_value = mock_user_response
|
||||
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_successful_migration(
|
||||
self, mock_user_settings, mock_user_response, mock_response
|
||||
):
|
||||
"""Test successful migrate_entries operation."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.get.return_value = mock_user_response
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.llm_model == get_default_litellm_model()
|
||||
assert result.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
# Verify migration steps were called
|
||||
assert (
|
||||
mock_client.post.call_count == 4
|
||||
) # create_team, update_user, add_user_to_team, update_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_team_and_users_budget_missing_config(self):
|
||||
"""Test update_team_and_users_budget when LiteLLM config is missing."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
# Should not raise an exception, just return early
|
||||
await LiteLlmManager.update_team_and_users_budget('test-team-id', 100.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_team_and_users_budget_successful(
|
||||
self, mock_team_response, mock_response
|
||||
):
|
||||
"""Test successful update_team_and_users_budget operation."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_client.get.return_value = mock_team_response
|
||||
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
'test-team-id', 100.0
|
||||
)
|
||||
|
||||
# Verify update_team and update_user_in_team were called
|
||||
assert (
|
||||
mock_client.post.call_count == 2
|
||||
) # update_team, update_user_in_team
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _create_team operation."""
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_team(
|
||||
mock_http_client, 'test-alias', 'test-team-id', 100.0
|
||||
)
|
||||
|
||||
mock_http_client.post.assert_called_once()
|
||||
call_args = mock_http_client.post.call_args
|
||||
assert 'http://test.com/team/new' in call_args[0]
|
||||
assert call_args[1]['json']['team_id'] == 'test-team-id'
|
||||
assert call_args[1]['json']['team_alias'] == 'test-alias'
|
||||
assert call_args[1]['json']['max_budget'] == 100.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_already_exists(self, mock_http_client):
|
||||
"""Test _create_team when team already exists."""
|
||||
error_response = MagicMock()
|
||||
error_response.is_success = False
|
||||
error_response.status_code = 400
|
||||
error_response.text = 'Team already exists. Please use a different team id'
|
||||
mock_http_client.post.return_value = error_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with patch.object(
|
||||
LiteLlmManager, '_update_team', new_callable=AsyncMock
|
||||
) as mock_update:
|
||||
await LiteLlmManager._create_team(
|
||||
mock_http_client, 'test-alias', 'test-team-id', 100.0
|
||||
)
|
||||
|
||||
mock_update.assert_called_once_with(
|
||||
mock_http_client, 'test-team-id', 'test-alias', 100.0
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_error(self, mock_http_client):
|
||||
"""Test _create_team with unexpected error."""
|
||||
error_response = MagicMock()
|
||||
error_response.is_success = False
|
||||
error_response.status_code = 500
|
||||
error_response.text = 'Internal server error'
|
||||
error_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
'Server error', request=MagicMock(), response=error_response
|
||||
)
|
||||
mock_http_client.post.return_value = error_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await LiteLlmManager._create_team(
|
||||
mock_http_client, 'test-alias', 'test-team-id', 100.0
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_team_success(self, mock_http_client, mock_team_response):
|
||||
"""Test successful _get_team operation."""
|
||||
mock_http_client.get.return_value = mock_team_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
result = await LiteLlmManager._get_team(
|
||||
mock_http_client, 'test-team-id'
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert 'team_memberships' in result
|
||||
mock_http_client.get.assert_called_once_with(
|
||||
'http://test.com/team/info?team_id=test-team-id'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _create_user operation."""
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
mock_http_client.post.assert_called_once()
|
||||
call_args = mock_http_client.post.call_args
|
||||
assert 'http://test.com/user/new' in call_args[0]
|
||||
assert call_args[1]['json']['user_email'] == 'test@example.com'
|
||||
assert call_args[1]['json']['user_id'] == 'test-user-id'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_duplicate_email(self, mock_http_client, mock_response):
|
||||
"""Test _create_user with duplicate email handling."""
|
||||
# First call fails with duplicate email
|
||||
error_response = MagicMock()
|
||||
error_response.is_success = False
|
||||
error_response.status_code = 400
|
||||
error_response.text = 'duplicate email'
|
||||
|
||||
# Second call succeeds
|
||||
mock_http_client.post.side_effect = [error_response, mock_response]
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._create_user(
|
||||
mock_http_client, 'test@example.com', 'test-user-id'
|
||||
)
|
||||
|
||||
assert mock_http_client.post.call_count == 2
|
||||
# Second call should have None email
|
||||
second_call_args = mock_http_client.post.call_args_list[1]
|
||||
assert second_call_args[1]['json']['user_email'] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_key_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _generate_key operation."""
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
result = await LiteLlmManager._generate_key(
|
||||
mock_http_client,
|
||||
'test-user-id',
|
||||
'test-team-id',
|
||||
'test-alias',
|
||||
{'test': 'metadata'},
|
||||
)
|
||||
|
||||
assert result == 'test-api-key'
|
||||
mock_http_client.post.assert_called_once()
|
||||
call_args = mock_http_client.post.call_args
|
||||
assert 'http://test.com/key/generate' in call_args[0]
|
||||
assert call_args[1]['json']['user_id'] == 'test-user-id'
|
||||
assert call_args[1]['json']['team_id'] == 'test-team-id'
|
||||
assert call_args[1]['json']['key_alias'] == 'test-alias'
|
||||
assert call_args[1]['json']['metadata'] == {'test': 'metadata'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_key_info_success(self, mock_http_client, mock_key_info_response):
|
||||
"""Test successful _get_key_info operation."""
|
||||
mock_http_client.get.return_value = mock_key_info_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with patch('storage.user_store.UserStore') as mock_user_store:
|
||||
# Mock user with org member
|
||||
mock_user = MagicMock()
|
||||
mock_org_member = MagicMock()
|
||||
mock_org_member.org_id = 'test-ord-id'
|
||||
mock_org_member.llm_api_key = 'test-api-key'
|
||||
mock_user.org_members = [mock_org_member]
|
||||
mock_user_store.get_user_by_id.return_value = mock_user
|
||||
|
||||
result = await LiteLlmManager._get_key_info(
|
||||
mock_http_client, 'test-ord-id', 'test-user-id'
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result['key_max_budget'] == 100.0
|
||||
assert result['key_spend'] == 25.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_key_info_no_user(self, mock_http_client):
|
||||
"""Test _get_key_info when user is not found."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
with patch('storage.user_store.UserStore') as mock_user_store:
|
||||
mock_user_store.get_user_by_id.return_value = None
|
||||
|
||||
result = await LiteLlmManager._get_key_info(
|
||||
mock_http_client, 'test-ord-id', 'test-user-id'
|
||||
)
|
||||
|
||||
assert result == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_key_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _delete_key operation."""
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
await LiteLlmManager._delete_key(mock_http_client, 'test-key-id')
|
||||
|
||||
mock_http_client.post.assert_called_once()
|
||||
call_args = mock_http_client.post.call_args
|
||||
assert 'http://test.com/key/delete' in call_args[0]
|
||||
assert call_args[1]['json']['keys'] == ['test-key-id']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_key_not_found(self, mock_http_client):
|
||||
"""Test _delete_key when key is not found (404 error)."""
|
||||
error_response = MagicMock()
|
||||
error_response.is_success = False
|
||||
error_response.status_code = 404
|
||||
error_response.text = 'Key not found'
|
||||
mock_http_client.post.return_value = error_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
# Should not raise an exception for 404
|
||||
await LiteLlmManager._delete_key(mock_http_client, 'test-key-id')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_http_client_decorator(self):
|
||||
"""Test the with_http_client decorator functionality."""
|
||||
|
||||
# Create a mock internal function
|
||||
async def mock_internal_fn(client, arg1, arg2, kwarg1=None):
|
||||
return f'client={type(client).__name__}, arg1={arg1}, arg2={arg2}, kwarg1={kwarg1}'
|
||||
|
||||
# Apply the decorator
|
||||
decorated_fn = LiteLlmManager.with_http_client(mock_internal_fn)
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
result = await decorated_fn('test1', 'test2', kwarg1='test3')
|
||||
|
||||
# Verify the client was injected as the first argument
|
||||
assert 'client=AsyncMock' in result
|
||||
assert 'arg1=test1' in result
|
||||
assert 'arg2=test2' in result
|
||||
assert 'kwarg1=test3' in result
|
||||
|
||||
def test_public_methods_exist(self):
|
||||
"""Test that all public wrapper methods exist and are properly decorated."""
|
||||
public_methods = [
|
||||
'create_team',
|
||||
'get_team',
|
||||
'update_team',
|
||||
'create_user',
|
||||
'get_user',
|
||||
'update_user',
|
||||
'delete_user',
|
||||
'add_user_to_team',
|
||||
'get_user_team_info',
|
||||
'update_user_in_team',
|
||||
'generate_key',
|
||||
'get_key_info',
|
||||
'delete_key',
|
||||
]
|
||||
|
||||
for method_name in public_methods:
|
||||
assert hasattr(LiteLlmManager, method_name)
|
||||
method = getattr(LiteLlmManager, method_name)
|
||||
assert callable(method)
|
||||
# The methods are created by the with_http_client decorator, so they're functions
|
||||
# We can verify they exist and are callable, which is the important part
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_missing_config_all_methods(self):
|
||||
"""Test that all methods handle missing configuration gracefully."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
mock_client = AsyncMock()
|
||||
|
||||
# Test all private methods that check for config
|
||||
await LiteLlmManager._create_team(
|
||||
mock_client, 'alias', 'team_id', 100.0
|
||||
)
|
||||
await LiteLlmManager._update_team(
|
||||
mock_client, 'team_id', 'alias', 100.0
|
||||
)
|
||||
await LiteLlmManager._create_user(mock_client, 'email', 'user_id')
|
||||
await LiteLlmManager._update_user(mock_client, 'user_id')
|
||||
await LiteLlmManager._delete_user(mock_client, 'user_id')
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
mock_client, 'user_id', 'team_id', 100.0
|
||||
)
|
||||
await LiteLlmManager._update_user_in_team(
|
||||
mock_client, 'user_id', 'team_id', 100.0
|
||||
)
|
||||
await LiteLlmManager._delete_key(mock_client, 'key_id')
|
||||
|
||||
result1 = await LiteLlmManager._get_team(mock_client, 'team_id')
|
||||
result2 = await LiteLlmManager._get_user(mock_client, 'user_id')
|
||||
result3 = await LiteLlmManager._generate_key(
|
||||
mock_client, 'user_id', 'team_id', 'alias', {}
|
||||
)
|
||||
result4 = await LiteLlmManager._get_user_team_info(
|
||||
mock_client, 'user_id', 'team_id'
|
||||
)
|
||||
result5 = await LiteLlmManager._get_key_info(
|
||||
mock_client, 'test-ord-id', 'user_id'
|
||||
)
|
||||
|
||||
# Methods that return None when config is missing
|
||||
assert result1 is None
|
||||
assert result2 is None
|
||||
assert result3 is None
|
||||
assert result4 is None
|
||||
assert result5 is None
|
||||
|
||||
# Verify no HTTP calls were made
|
||||
mock_client.get.assert_not_called()
|
||||
mock_client.post.assert_not_called()
|
||||
70
enterprise/tests/unit/test_models.py
Normal file
70
enterprise/tests/unit/test_models.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
Test that the models are correctly defined.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
def test_user_model(session_maker):
|
||||
"""Test that the User model works correctly."""
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test_org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
# Create a test user
|
||||
test_user_id = uuid4()
|
||||
user = User(id=test_user_id, current_org_id=org.id, language='en')
|
||||
session.add(user)
|
||||
session.flush()
|
||||
|
||||
# Create org_member relationship
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=1,
|
||||
llm_api_key='test-api-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
|
||||
# Query the user
|
||||
queried_user = session.query(User).filter(User.id == test_user_id).first()
|
||||
assert queried_user is not None
|
||||
assert queried_user.language == 'en'
|
||||
|
||||
# Query the org
|
||||
queried_org = session.query(Org).filter(Org.id == org.id).first()
|
||||
assert queried_org is not None
|
||||
assert queried_org.name == 'test_org'
|
||||
|
||||
# Query the org_member relationship
|
||||
queried_org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org.id, OrgMember.user_id == user.id)
|
||||
.first()
|
||||
)
|
||||
assert queried_org_member is not None
|
||||
assert queried_org_member.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
253
enterprise/tests/unit/test_org_member_store.py
Normal file
253
enterprise/tests/unit/test_org_member_store.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
# Mock the database module before importing OrgMemberStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
|
||||
|
||||
def test_get_org_members(session_maker):
|
||||
# Test getting org_members by org ID
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user1 = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
user2 = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user1, user2, role])
|
||||
session.flush()
|
||||
|
||||
org_member1 = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user1.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-1',
|
||||
status='active',
|
||||
)
|
||||
org_member2 = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user2.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-2',
|
||||
status='active',
|
||||
)
|
||||
session.add_all([org_member1, org_member2])
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
org_members = OrgMemberStore.get_org_members(org_id)
|
||||
assert len(org_members) == 2
|
||||
api_keys = [om.llm_api_key.get_secret_value() for om in org_members]
|
||||
assert 'test-key-1' in api_keys
|
||||
assert 'test-key-2' in api_keys
|
||||
|
||||
|
||||
def test_get_user_orgs(session_maker):
|
||||
# Test getting org_members by user ID
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org1 = Org(name='test-org-1')
|
||||
org2 = Org(name='test-org-2')
|
||||
session.add_all([org1, org2])
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org1.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
org_member1 = OrgMember(
|
||||
org_id=org1.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-1',
|
||||
status='active',
|
||||
)
|
||||
org_member2 = OrgMember(
|
||||
org_id=org2.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-2',
|
||||
status='active',
|
||||
)
|
||||
session.add_all([org_member1, org_member2])
|
||||
session.commit()
|
||||
user_id = user.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
org_members = OrgMemberStore.get_user_orgs(user_id)
|
||||
assert len(org_members) == 2
|
||||
api_keys = [ou.llm_api_key.get_secret_value() for ou in org_members]
|
||||
assert 'test-key-1' in api_keys
|
||||
assert 'test-key-2' in api_keys
|
||||
|
||||
|
||||
def test_get_org_member(session_maker):
|
||||
# Test getting org_member by org and user ID
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
retrieved_org_member = OrgMemberStore.get_org_member(org_id, user_id)
|
||||
assert retrieved_org_member is not None
|
||||
assert retrieved_org_member.org_id == org_id
|
||||
assert retrieved_org_member.user_id == user_id
|
||||
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key'
|
||||
|
||||
|
||||
def test_add_user_to_org(session_maker):
|
||||
# Test adding a user to an org
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
role_id = role.id
|
||||
|
||||
# Test creation
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
org_member = OrgMemberStore.add_user_to_org(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
llm_api_key='new-test-key',
|
||||
status='active',
|
||||
)
|
||||
|
||||
assert org_member is not None
|
||||
assert org_member.org_id == org_id
|
||||
assert org_member.user_id == user_id
|
||||
assert org_member.role_id == role_id
|
||||
assert org_member.llm_api_key.get_secret_value() == 'new-test-key'
|
||||
assert org_member.status == 'active'
|
||||
|
||||
|
||||
def test_update_user_role_in_org(session_maker):
|
||||
# Test updating user role in org
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role1 = Role(name='admin', rank=1)
|
||||
role2 = Role(name='user', rank=2)
|
||||
session.add_all([user, role1, role2])
|
||||
session.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role1.id,
|
||||
llm_api_key='test-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
role2_id = role2.id
|
||||
|
||||
# Test update
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
updated_org_member = OrgMemberStore.update_user_role_in_org(
|
||||
org_id=org_id, user_id=user_id, role_id=role2_id, status='inactive'
|
||||
)
|
||||
|
||||
assert updated_org_member is not None
|
||||
assert updated_org_member.role_id == role2_id
|
||||
assert updated_org_member.status == 'inactive'
|
||||
|
||||
|
||||
def test_update_user_role_in_org_not_found(session_maker):
|
||||
# Test updating org_member that doesn't exist
|
||||
from uuid import uuid4
|
||||
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
updated_org_member = OrgMemberStore.update_user_role_in_org(
|
||||
org_id=uuid4(), user_id=99999, role_id=1
|
||||
)
|
||||
assert updated_org_member is None
|
||||
|
||||
|
||||
def test_remove_user_from_org(session_maker):
|
||||
# Test removing a user from an org
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
|
||||
# Test removal
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
result = OrgMemberStore.remove_user_from_org(org_id, user_id)
|
||||
assert result is True
|
||||
|
||||
# Verify it's removed
|
||||
retrieved_org_member = OrgMemberStore.get_org_member(org_id, user_id)
|
||||
assert retrieved_org_member is None
|
||||
|
||||
|
||||
def test_remove_user_from_org_not_found(session_maker):
|
||||
# Test removing user from org that doesn't exist
|
||||
from uuid import uuid4
|
||||
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
result = OrgMemberStore.remove_user_from_org(uuid4(), 99999)
|
||||
assert result is False
|
||||
197
enterprise/tests/unit/test_org_store.py
Normal file
197
enterprise/tests/unit/test_org_store.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
# Mock the database module before importing OrgStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_api():
|
||||
api_key_patch = patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test_key')
|
||||
api_url_patch = patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'
|
||||
)
|
||||
team_id_patch = patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test_team')
|
||||
client_patch = patch('httpx.AsyncClient')
|
||||
|
||||
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.json = MagicMock(return_value={'key': 'test_api_key'})
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.patch.return_value = (
|
||||
mock_response
|
||||
)
|
||||
yield mock_client
|
||||
|
||||
|
||||
def test_get_org_by_id(session_maker, mock_litellm_api):
|
||||
# Test getting org by ID
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_org_by_id(org_id)
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.id == org_id
|
||||
assert retrieved_org.name == 'test-org'
|
||||
|
||||
|
||||
def test_get_org_by_id_not_found(session_maker):
|
||||
# Test getting org by ID when it doesn't exist
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
non_existent_id = uuid.uuid4()
|
||||
retrieved_org = OrgStore.get_org_by_id(non_existent_id)
|
||||
assert retrieved_org is None
|
||||
|
||||
|
||||
def test_list_orgs(session_maker, mock_litellm_api):
|
||||
# Test listing all orgs
|
||||
with session_maker() as session:
|
||||
# Create test orgs
|
||||
org1 = Org(name='test-org-1')
|
||||
org2 = Org(name='test-org-2')
|
||||
session.add_all([org1, org2])
|
||||
session.commit()
|
||||
|
||||
# Test listing
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
orgs = OrgStore.list_orgs()
|
||||
assert len(orgs) >= 2
|
||||
org_names = [org.name for org in orgs]
|
||||
assert 'test-org-1' in org_names
|
||||
assert 'test-org-2' in org_names
|
||||
|
||||
|
||||
def test_update_org(session_maker, mock_litellm_api):
|
||||
# Test updating org details
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org', agent='CodeActAgent')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Test update
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
updated_org = OrgStore.update_org(
|
||||
org_id=org_id, kwargs={'name': 'updated-org', 'agent': 'PlannerAgent'}
|
||||
)
|
||||
|
||||
assert updated_org is not None
|
||||
assert updated_org.name == 'updated-org'
|
||||
assert updated_org.agent == 'PlannerAgent'
|
||||
|
||||
|
||||
def test_update_org_not_found(session_maker):
|
||||
# Test updating org that doesn't exist
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
from uuid import uuid4
|
||||
|
||||
updated_org = OrgStore.update_org(
|
||||
org_id=uuid4(), kwargs={'name': 'updated-org'}
|
||||
)
|
||||
assert updated_org is None
|
||||
|
||||
|
||||
def test_create_org(session_maker, mock_litellm_api):
|
||||
# Test creating a new org
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
org = OrgStore.create_org(kwargs={'name': 'new-org', 'agent': 'CodeActAgent'})
|
||||
|
||||
assert org is not None
|
||||
assert org.name == 'new-org'
|
||||
assert org.agent == 'CodeActAgent'
|
||||
assert org.id is not None
|
||||
|
||||
|
||||
def test_get_org_by_name(session_maker, mock_litellm_api):
|
||||
# Test getting org by name
|
||||
with session_maker() as session:
|
||||
# Create a test org
|
||||
org = Org(name='test-org-by-name')
|
||||
session.add(org)
|
||||
session.commit()
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_org_by_name('test-org-by-name')
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.name == 'test-org-by-name'
|
||||
|
||||
|
||||
def test_get_current_org_from_keycloak_user_id(session_maker, mock_litellm_api):
|
||||
# Test getting current org from user ID
|
||||
test_user_id = uuid.uuid4()
|
||||
with session_maker() as session:
|
||||
# Create test data
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
from storage.user import User
|
||||
|
||||
user = User(id=test_user_id, current_org_id=org.id)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
# Test retrieval
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
):
|
||||
retrieved_org = OrgStore.get_current_org_from_keycloak_user_id(
|
||||
str(test_user_id)
|
||||
)
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.name == 'test-org'
|
||||
|
||||
|
||||
def test_get_kwargs_from_settings():
|
||||
# Test extracting org kwargs from settings
|
||||
settings = Settings(
|
||||
language='es',
|
||||
agent='CodeActAgent',
|
||||
llm_model='gpt-4',
|
||||
llm_api_key=SecretStr('test-key'),
|
||||
enable_sound_notifications=True,
|
||||
)
|
||||
|
||||
kwargs = OrgStore.get_kwargs_from_settings(settings)
|
||||
|
||||
# Should only include fields that exist in Org model
|
||||
assert 'agent' in kwargs
|
||||
assert 'default_llm_model' in kwargs
|
||||
assert kwargs['agent'] == 'CodeActAgent'
|
||||
assert kwargs['default_llm_model'] == 'gpt-4'
|
||||
# Should not include fields that don't exist in Org model
|
||||
assert 'language' not in kwargs # language is not in Org model
|
||||
assert 'llm_api_key' not in kwargs
|
||||
assert 'llm_model' not in kwargs
|
||||
assert 'enable_sound_notifications' not in kwargs
|
||||
@@ -1,32 +1,15 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.github.github_view import get_user_proactive_conversation_setting
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from integrations.github.github_view import get_user_proactive_conversation_setting
|
||||
from storage.org import Org
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
# Mock the call_sync_from_async function to return the result of the function directly
|
||||
def mock_call_sync_from_async(func, *args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
filter = MagicMock()
|
||||
|
||||
# Mock the context manager behavior
|
||||
session.__enter__.return_value = session
|
||||
|
||||
session.query.return_value = query
|
||||
query.filter.return_value = filter
|
||||
|
||||
return session, query, filter
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_no_user_id():
|
||||
"""Test that the function returns False when no user ID is provided."""
|
||||
with patch(
|
||||
@@ -42,75 +25,82 @@ async def test_get_user_proactive_conversation_setting_no_user_id():
|
||||
assert await get_user_proactive_conversation_setting(None) is False
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_not_found(mock_session):
|
||||
async def test_get_user_proactive_conversation_setting_user_not_found():
|
||||
"""Test that False is returned when the user is not found."""
|
||||
session, query, filter = mock_session
|
||||
filter.first.return_value = None
|
||||
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=None,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is False
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_none(mock_session):
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_none():
|
||||
"""Test that False is returned when the user setting is None."""
|
||||
session, query, filter = mock_session
|
||||
user_settings = MagicMock(spec=UserSettings)
|
||||
user_settings.enable_proactive_conversation_starters = None
|
||||
filter.first.return_value = user_settings
|
||||
mock_org = MagicMock(spec=Org)
|
||||
mock_org.enable_proactive_conversation_starters = None
|
||||
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is False
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_true(mock_session):
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_true():
|
||||
"""Test that True is returned when the user setting is True and the global setting is True."""
|
||||
session, query, filter = mock_session
|
||||
user_settings = MagicMock(spec=UserSettings)
|
||||
user_settings.enable_proactive_conversation_starters = True
|
||||
filter.first.return_value = user_settings
|
||||
mock_org = MagicMock(spec=Org)
|
||||
mock_org.enable_proactive_conversation_starters = True
|
||||
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is True
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_false(mock_session):
|
||||
async def test_get_user_proactive_conversation_setting_user_setting_false():
|
||||
"""Test that False is returned when the user setting is False, regardless of global setting."""
|
||||
session, query, filter = mock_session
|
||||
user_settings = MagicMock(spec=UserSettings)
|
||||
user_settings.enable_proactive_conversation_starters = False
|
||||
filter.first.return_value = user_settings
|
||||
mock_org = MagicMock(spec=Org)
|
||||
mock_org.enable_proactive_conversation_starters = False
|
||||
|
||||
with patch('integrations.github.github_view.session_maker', return_value=session):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
|
||||
True,
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
):
|
||||
with patch(
|
||||
'integrations.github.github_view.call_sync_from_async',
|
||||
side_effect=mock_call_sync_from_async,
|
||||
):
|
||||
assert await get_user_proactive_conversation_setting('user-id') is False
|
||||
assert (
|
||||
await get_user_proactive_conversation_setting(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
83
enterprise/tests/unit/test_role_store.py
Normal file
83
enterprise/tests/unit/test_role_store.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
# Mock the database module before importing RoleStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
|
||||
def test_get_role_by_id(session_maker):
|
||||
# Test getting role by ID
|
||||
with session_maker() as session:
|
||||
# Create a test role
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
role_id = role.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_id(role_id)
|
||||
assert retrieved_role is not None
|
||||
assert retrieved_role.id == role_id
|
||||
assert retrieved_role.name == 'admin'
|
||||
|
||||
|
||||
def test_get_role_by_id_not_found(session_maker):
|
||||
# Test getting role by ID when it doesn't exist
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_id(99999)
|
||||
assert retrieved_role is None
|
||||
|
||||
|
||||
def test_get_role_by_name(session_maker):
|
||||
# Test getting role by name
|
||||
with session_maker() as session:
|
||||
# Create a test role
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
role_id = role.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_name('admin')
|
||||
assert retrieved_role is not None
|
||||
assert retrieved_role.id == role_id
|
||||
assert retrieved_role.name == 'admin'
|
||||
|
||||
|
||||
def test_get_role_by_name_not_found(session_maker):
|
||||
# Test getting role by name when it doesn't exist
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
retrieved_role = RoleStore.get_role_by_name('nonexistent')
|
||||
assert retrieved_role is None
|
||||
|
||||
|
||||
def test_list_roles(session_maker):
|
||||
# Test listing all roles
|
||||
with session_maker() as session:
|
||||
# Create test roles
|
||||
role1 = Role(name='admin', rank=1)
|
||||
role2 = Role(name='user', rank=2)
|
||||
session.add_all([role1, role2])
|
||||
session.commit()
|
||||
|
||||
# Test listing
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
roles = RoleStore.list_roles()
|
||||
assert len(roles) >= 2
|
||||
role_names = [role.name for role in roles]
|
||||
assert 'admin' in role_names
|
||||
assert 'user' in role_names
|
||||
|
||||
|
||||
def test_create_role(session_maker):
|
||||
# Test creating a new role
|
||||
with patch('storage.role_store.session_maker', session_maker):
|
||||
role = RoleStore.create_role(name='moderator', rank=2)
|
||||
|
||||
assert role is not None
|
||||
assert role.name == 'moderator'
|
||||
assert role.rank == 2
|
||||
assert role.id is not None
|
||||
@@ -1,11 +1,26 @@
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
# Import the actual StoredConversationMetadata from OpenHands core
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
)
|
||||
|
||||
# Mock the lazy import to return the actual class
|
||||
with patch(
|
||||
'storage.stored_conversation_metadata.StoredConversationMetadata',
|
||||
StoredConversationMetadata,
|
||||
):
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_call_sync_from_async():
|
||||
@@ -20,12 +35,25 @@ def mock_call_sync_from_async():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_user_store():
|
||||
"""Mock UserStore.get_user_by_id to return a mock user"""
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.current_org_id = UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
with patch(
|
||||
'storage.saas_conversation_store.UserStore.get_user_by_id',
|
||||
return_value=mock_user,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_and_get(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
metadata = ConversationMetadata(
|
||||
conversation_id='my-conversation-id',
|
||||
user_id='12345',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
selected_repository='my-repo',
|
||||
selected_branch=None,
|
||||
created_at=datetime.now(UTC),
|
||||
@@ -47,13 +75,13 @@ async def test_save_and_get(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
|
||||
# Create test conversations with different timestamps
|
||||
conversations = [
|
||||
ConversationMetadata(
|
||||
conversation_id=f'conv-{i}',
|
||||
user_id='12345',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
selected_repository='repo',
|
||||
selected_branch=None,
|
||||
created_at=datetime(2024, 1, i + 1, tzinfo=UTC),
|
||||
@@ -92,10 +120,10 @@ async def test_search(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_metadata(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
metadata = ConversationMetadata(
|
||||
conversation_id='to-delete',
|
||||
user_id='12345',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
selected_repository='repo',
|
||||
selected_branch=None,
|
||||
created_at=datetime.now(UTC),
|
||||
@@ -112,17 +140,17 @@ async def test_delete_metadata(session_maker):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_metadata(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await store.get_metadata('nonexistent-id')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exists(session_maker):
|
||||
store = SaasConversationStore('12345', session_maker)
|
||||
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
|
||||
metadata = ConversationMetadata(
|
||||
conversation_id='exists-test',
|
||||
user_id='12345',
|
||||
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
|
||||
selected_repository='repo',
|
||||
selected_branch='test-branch',
|
||||
created_at=datetime.now(UTC),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
@@ -19,6 +20,14 @@ def mock_config():
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Mock user with org_id."""
|
||||
user = MagicMock()
|
||||
user.current_org_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def secrets_store(session_maker, mock_config):
|
||||
return SaasSecretsStore('user-id', session_maker, mock_config)
|
||||
@@ -26,7 +35,11 @@ def secrets_store(session_maker, mock_config):
|
||||
|
||||
class TestSaasSecretsStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_and_load(self, secrets_store):
|
||||
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
|
||||
async def test_store_and_load(self, mock_get_user, secrets_store, mock_user):
|
||||
# Setup mock
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Create a Secrets object with some test data
|
||||
user_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
@@ -59,7 +72,10 @@ class TestSaasSecretsStore:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_encryption_decryption(self, secrets_store):
|
||||
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
|
||||
async def test_encryption_decryption(self, mock_get_user, secrets_store, mock_user):
|
||||
# Setup mock
|
||||
mock_get_user.return_value = mock_user
|
||||
# Create a Secrets object with sensitive data
|
||||
user_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
@@ -89,6 +105,7 @@ class TestSaasSecretsStore:
|
||||
stored = (
|
||||
session.query(StoredCustomSecrets)
|
||||
.filter(StoredCustomSecrets.keycloak_user_id == 'user-id')
|
||||
.filter(StoredCustomSecrets.org_id == mock_user.current_org_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -152,7 +169,12 @@ class TestSaasSecretsStore:
|
||||
assert await secrets_store.load() is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_existing_secrets(self, secrets_store):
|
||||
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
|
||||
async def test_update_existing_secrets(
|
||||
self, mock_get_user, secrets_store, mock_user
|
||||
):
|
||||
# Setup mock
|
||||
mock_get_user.return_value = mock_user
|
||||
# Create and store initial secrets
|
||||
initial_secrets = Secrets(
|
||||
custom_secrets=MappingProxyType(
|
||||
|
||||
@@ -2,65 +2,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
from server.constants import (
|
||||
CURRENT_USER_SETTINGS_VERSION,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
)
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.settings import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_get_response():
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.json = MagicMock(return_value={'user_info': {}})
|
||||
return mock_response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_post_response():
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.json = MagicMock(return_value={'key': 'test_api_key'})
|
||||
return mock_response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_api(mock_litellm_get_response, mock_litellm_post_response):
|
||||
api_key_patch = patch('storage.saas_settings_store.LITE_LLM_API_KEY', 'test_key')
|
||||
api_url_patch = patch(
|
||||
'storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from server.constants import (
|
||||
LITE_LLM_API_URL,
|
||||
)
|
||||
team_id_patch = patch('storage.saas_settings_store.LITE_LLM_TEAM_ID', 'test_team')
|
||||
client_patch = patch('httpx.AsyncClient')
|
||||
|
||||
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_litellm_get_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_litellm_post_response
|
||||
)
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stripe():
|
||||
search_patch = patch(
|
||||
'stripe.Customer.search_async',
|
||||
AsyncMock(return_value=MagicMock(id='mock-customer-id')),
|
||||
)
|
||||
payment_patch = patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[{}])),
|
||||
)
|
||||
with search_patch, payment_patch:
|
||||
yield
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -83,41 +35,42 @@ def mock_config():
|
||||
|
||||
@pytest.fixture
|
||||
def settings_store(session_maker, mock_config):
|
||||
store = SaasSettingsStore('user-id', session_maker, mock_config)
|
||||
store = SaasSettingsStore(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker, mock_config
|
||||
)
|
||||
|
||||
# Patch the store method directly to filter out email and email_verified
|
||||
original_load = store.load
|
||||
original_create_default = store.create_default_settings
|
||||
original_update_litellm = store.update_settings_with_litellm_default
|
||||
|
||||
# Patch the load method to add email and email_verified
|
||||
# Patch the load method to read from UserSettings table directly (for testing)
|
||||
async def patched_load():
|
||||
settings = await original_load()
|
||||
if settings:
|
||||
# Add email and email_verified fields to mimic SaasUserAuth behavior
|
||||
with store.session_maker() as session:
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == store.user_id)
|
||||
.first()
|
||||
)
|
||||
if not user_settings:
|
||||
# Return default settings
|
||||
return Settings(
|
||||
llm_api_key=SecretStr('test_api_key'),
|
||||
llm_base_url='http://test.url',
|
||||
agent='CodeActAgent',
|
||||
language='en',
|
||||
)
|
||||
|
||||
# Decrypt and convert to Settings
|
||||
kwargs = {}
|
||||
for column in UserSettings.__table__.columns:
|
||||
if column.name != 'keycloak_user_id':
|
||||
value = getattr(user_settings, column.name, None)
|
||||
if value is not None:
|
||||
kwargs[column.name] = value
|
||||
|
||||
store._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
settings.email = 'test@example.com'
|
||||
settings.email_verified = True
|
||||
return settings
|
||||
return settings
|
||||
|
||||
# Patch the create_default_settings method to add email and email_verified
|
||||
async def patched_create_default(settings):
|
||||
settings = await original_create_default(settings)
|
||||
if settings:
|
||||
# Add email and email_verified fields to mimic SaasUserAuth behavior
|
||||
settings.email = 'test@example.com'
|
||||
settings.email_verified = True
|
||||
return settings
|
||||
|
||||
# Patch the update_settings_with_litellm_default method
|
||||
async def patched_update_litellm(settings):
|
||||
updated_settings = await original_update_litellm(settings)
|
||||
if updated_settings:
|
||||
# Add email and email_verified fields to mimic SaasUserAuth behavior
|
||||
updated_settings.email = 'test@example.com'
|
||||
updated_settings.email_verified = True
|
||||
return updated_settings
|
||||
|
||||
# Patch the store method to filter out email and email_verified
|
||||
# Patch the store method to write to UserSettings table directly (for testing)
|
||||
async def patched_store(item):
|
||||
if item:
|
||||
# Make a copy of the item without email and email_verified
|
||||
@@ -146,11 +99,9 @@ def settings_store(session_maker, mock_config):
|
||||
for key, value in item_dict.items():
|
||||
if key in existing.__class__.__table__.columns:
|
||||
setattr(existing, key, value)
|
||||
existing.user_version = CURRENT_USER_SETTINGS_VERSION
|
||||
session.merge(existing)
|
||||
else:
|
||||
item_dict['keycloak_user_id'] = store.user_id
|
||||
item_dict['user_version'] = CURRENT_USER_SETTINGS_VERSION
|
||||
settings = UserSettings(**item_dict)
|
||||
session.add(settings)
|
||||
session.commit()
|
||||
@@ -158,8 +109,6 @@ def settings_store(session_maker, mock_config):
|
||||
# Replace the methods with our patched versions
|
||||
store.store = patched_store
|
||||
store.load = patched_load
|
||||
store.create_default_settings = patched_create_default
|
||||
store.update_settings_with_litellm_default = patched_update_litellm
|
||||
return store
|
||||
|
||||
|
||||
@@ -197,17 +146,11 @@ async def test_store_and_load_keycloak_user(settings_store):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_returns_default_when_not_found(
|
||||
settings_store, mock_litellm_api, mock_stripe, mock_github_user, session_maker
|
||||
):
|
||||
async def test_load_returns_default_when_not_found(settings_store, session_maker):
|
||||
file_store = MagicMock()
|
||||
file_store.read.side_effect = FileNotFoundError()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'storage.saas_settings_store.get_file_store',
|
||||
MagicMock(return_value=file_store),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
loaded_settings = await settings_store.load()
|
||||
@@ -218,233 +161,9 @@ async def test_load_returns_default_when_not_found(
|
||||
assert loaded_settings.llm_base_url == 'http://test.url'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
settings = Settings()
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
settings = await settings_store.update_settings_with_litellm_default(settings)
|
||||
|
||||
assert settings.agent == 'CodeActAgent'
|
||||
assert settings.llm_api_key
|
||||
assert settings.llm_api_key.get_secret_value() == 'test_api_key'
|
||||
assert settings.llm_base_url == 'http://test.url'
|
||||
|
||||
# Get the actual call arguments
|
||||
call_args = mock_litellm_api.return_value.__aenter__.return_value.post.call_args[1]
|
||||
|
||||
# Check that the URL and most of the JSON payload match what we expect
|
||||
assert call_args['json']['user_email'] == 'testy@tester.com'
|
||||
assert call_args['json']['models'] == []
|
||||
assert call_args['json']['max_budget'] == 10.0
|
||||
assert call_args['json']['user_id'] == 'user-id'
|
||||
assert call_args['json']['teams'] == ['test_team']
|
||||
assert call_args['json']['auto_create_key'] is True
|
||||
assert call_args['json']['send_invite_email'] is False
|
||||
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
|
||||
assert 'model' in call_args['json']['metadata']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_no_user_id():
|
||||
store = SaasSettingsStore('', MagicMock(), MagicMock())
|
||||
settings = await store.create_default_settings(None)
|
||||
assert settings is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_require_payment_enabled(
|
||||
settings_store, mock_stripe
|
||||
):
|
||||
# Mock stripe_service.has_payment_method to return False
|
||||
with (
|
||||
patch('storage.saas_settings_store.REQUIRE_PAYMENT', True),
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[])),
|
||||
),
|
||||
patch(
|
||||
'integrations.stripe_service.session_maker', settings_store.session_maker
|
||||
),
|
||||
):
|
||||
settings = await settings_store.create_default_settings(None)
|
||||
assert settings is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_require_payment_disabled(
|
||||
settings_store, mock_stripe, mock_github_user, mock_litellm_api, session_maker
|
||||
):
|
||||
# Even without payment method, should get default settings when REQUIRE_PAYMENT is False
|
||||
file_store = MagicMock()
|
||||
file_store.read.side_effect = FileNotFoundError()
|
||||
with (
|
||||
patch('storage.saas_settings_store.REQUIRE_PAYMENT', False),
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[])),
|
||||
),
|
||||
patch(
|
||||
'storage.saas_settings_store.get_file_store',
|
||||
MagicMock(return_value=file_store),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
settings = await settings_store.create_default_settings(None)
|
||||
assert settings is not None
|
||||
assert settings.language == 'en'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_lite_llm_settings_no_api_config(settings_store):
|
||||
with (
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_KEY', None),
|
||||
patch('storage.saas_settings_store.LITE_LLM_API_URL', None),
|
||||
):
|
||||
settings = Settings()
|
||||
settings = await settings_store.update_settings_with_litellm_default(settings)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_default_error(settings_store):
|
||||
with patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'duplicate@example.com'}),
|
||||
):
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
AsyncMock(
|
||||
json=MagicMock(
|
||||
return_value={'user_info': {'max_budget': 10, 'spend': 5}}
|
||||
)
|
||||
)
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value.is_success = False
|
||||
settings = Settings()
|
||||
settings = await settings_store.update_settings_with_litellm_default(
|
||||
settings
|
||||
)
|
||||
assert settings is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_settings_with_litellm_retry_on_duplicate_email(
|
||||
settings_store, mock_litellm_api, session_maker
|
||||
):
|
||||
# First response is a delete and succeeds
|
||||
mock_delete_response = MagicMock()
|
||||
mock_delete_response.is_success = True
|
||||
mock_delete_response.status_code = 200
|
||||
|
||||
# Second response fails with duplicate email error
|
||||
mock_error_response = MagicMock()
|
||||
mock_error_response.is_success = False
|
||||
mock_error_response.status_code = 400
|
||||
mock_error_response.text = 'User with this email already exists'
|
||||
|
||||
# Thire response succeeds with no email
|
||||
mock_success_response = MagicMock()
|
||||
mock_success_response.is_success = True
|
||||
mock_success_response.json = MagicMock(return_value={'key': 'new_test_api_key'})
|
||||
|
||||
# Set up mocks
|
||||
post_mock = AsyncMock()
|
||||
post_mock.side_effect = [
|
||||
mock_delete_response,
|
||||
mock_error_response,
|
||||
mock_success_response,
|
||||
]
|
||||
mock_litellm_api.return_value.__aenter__.return_value.post = post_mock
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'duplicate@example.com'}),
|
||||
),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
settings = Settings()
|
||||
settings = await settings_store.update_settings_with_litellm_default(settings)
|
||||
|
||||
assert settings is not None
|
||||
assert settings.llm_api_key
|
||||
assert settings.llm_api_key.get_secret_value() == 'new_test_api_key'
|
||||
|
||||
# Verify second call was with email
|
||||
second_call_args = post_mock.call_args_list[1][1]
|
||||
assert second_call_args['json']['user_email'] == 'duplicate@example.com'
|
||||
|
||||
# Verify third call was with None for email
|
||||
third_call_args = post_mock.call_args_list[2][1]
|
||||
assert third_call_args['json']['user_email'] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_in_lite_llm(settings_store):
|
||||
# Test the _create_user_in_lite_llm method directly
|
||||
mock_client = AsyncMock()
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
# Test with email
|
||||
await settings_store._create_user_in_lite_llm(
|
||||
mock_client, 'test@example.com', 50, 10
|
||||
)
|
||||
|
||||
# Get the actual call arguments
|
||||
call_args = mock_client.post.call_args[1]
|
||||
|
||||
# Check that the URL and most of the JSON payload match what we expect
|
||||
assert call_args['json']['user_email'] == 'test@example.com'
|
||||
assert call_args['json']['models'] == []
|
||||
assert call_args['json']['max_budget'] == 50
|
||||
assert call_args['json']['spend'] == 10
|
||||
assert call_args['json']['user_id'] == 'user-id'
|
||||
assert call_args['json']['teams'] == [LITE_LLM_TEAM_ID]
|
||||
assert call_args['json']['auto_create_key'] is True
|
||||
assert call_args['json']['send_invite_email'] is False
|
||||
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
|
||||
assert 'model' in call_args['json']['metadata']
|
||||
|
||||
# Test with None email
|
||||
mock_client.post.reset_mock()
|
||||
await settings_store._create_user_in_lite_llm(mock_client, None, 25, 15)
|
||||
|
||||
# Get the actual call arguments
|
||||
call_args = mock_client.post.call_args[1]
|
||||
|
||||
# Check that the URL and most of the JSON payload match what we expect
|
||||
assert call_args['json']['user_email'] is None
|
||||
assert call_args['json']['models'] == []
|
||||
assert call_args['json']['max_budget'] == 25
|
||||
assert call_args['json']['spend'] == 15
|
||||
assert call_args['json']['user_id'] == str(settings_store.user_id)
|
||||
assert call_args['json']['teams'] == [LITE_LLM_TEAM_ID]
|
||||
assert call_args['json']['auto_create_key'] is True
|
||||
assert call_args['json']['send_invite_email'] is False
|
||||
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
|
||||
assert 'model' in call_args['json']['metadata']
|
||||
|
||||
# Verify response is returned correctly
|
||||
assert (
|
||||
await settings_store._create_user_in_lite_llm(
|
||||
mock_client, 'email@test.com', 30, 7
|
||||
)
|
||||
== mock_response
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_encryption(settings_store):
|
||||
settings_store.user_id = 'mock-id' # GitHub user ID
|
||||
settings_store.user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' # GitHub user ID
|
||||
settings = Settings(
|
||||
llm_api_key=SecretStr('secret_key'),
|
||||
agent='smith',
|
||||
@@ -456,7 +175,9 @@ async def test_encryption(settings_store):
|
||||
with settings_store.session_maker() as session:
|
||||
stored = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == 'mock-id')
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
.first()
|
||||
)
|
||||
# The stored key should be encrypted
|
||||
|
||||
@@ -535,3 +535,115 @@ def test_get_api_key_from_header_with_invalid_authorization_format():
|
||||
|
||||
# Assert that None was returned
|
||||
assert api_key is None
|
||||
|
||||
|
||||
def test_get_api_key_from_header_with_x_access_token():
|
||||
"""Test that get_api_key_from_header extracts API key from X-Access-Token header."""
|
||||
# Create a mock request with X-Access-Token header
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {'X-Access-Token': 'access_token_key'}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key was correctly extracted
|
||||
assert api_key == 'access_token_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_priority_authorization_over_x_access_token():
|
||||
"""Test that Authorization header takes priority over X-Access-Token header."""
|
||||
# Create a mock request with both headers
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': 'Bearer auth_api_key',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from Authorization header was used
|
||||
assert api_key == 'auth_api_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_priority_x_session_over_x_access_token():
|
||||
"""Test that X-Session-API-Key header takes priority over X-Access-Token header."""
|
||||
# Create a mock request with both headers
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'X-Session-API-Key': 'session_api_key',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from X-Session-API-Key header was used
|
||||
assert api_key == 'session_api_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_all_three_headers():
|
||||
"""Test header priority when all three headers are present."""
|
||||
# Create a mock request with all three headers
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': 'Bearer auth_api_key',
|
||||
'X-Session-API-Key': 'session_api_key',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from Authorization header was used (highest priority)
|
||||
assert api_key == 'auth_api_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_invalid_authorization_fallback_to_x_access_token():
|
||||
"""Test that invalid Authorization header falls back to X-Access-Token."""
|
||||
# Create a mock request with invalid Authorization header and X-Access-Token
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': 'InvalidFormat api_key',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from X-Access-Token header was used
|
||||
assert api_key == 'access_token_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_empty_headers():
|
||||
"""Test that empty header values are handled correctly."""
|
||||
# Create a mock request with empty header values
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': '',
|
||||
'X-Session-API-Key': '',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that the API key from X-Access-Token header was used
|
||||
assert api_key == 'access_token_key'
|
||||
|
||||
|
||||
def test_get_api_key_from_header_bearer_with_empty_token():
|
||||
"""Test that Bearer header with empty token falls back to other headers."""
|
||||
# Create a mock request with Bearer header with empty token
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers = {
|
||||
'Authorization': 'Bearer ',
|
||||
'X-Access-Token': 'access_token_key',
|
||||
}
|
||||
|
||||
# Call the function
|
||||
api_key = get_api_key_from_header(mock_request)
|
||||
|
||||
# Assert that empty string from Bearer is returned (current behavior)
|
||||
# This tests the current implementation behavior
|
||||
assert api_key == ''
|
||||
|
||||
@@ -3,27 +3,30 @@ This test file verifies that the stripe_service functions properly use the datab
|
||||
to store and retrieve customer IDs.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import stripe
|
||||
from integrations.stripe_service import (
|
||||
find_customer_id_by_user_id,
|
||||
find_or_create_customer,
|
||||
find_or_create_customer_by_user_id,
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.stripe_customer import Base as StripeCustomerBase
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.user_settings import Base as UserBase
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
|
||||
UserBase.metadata.create_all(engine)
|
||||
StripeCustomerBase.metadata.create_all(engine)
|
||||
# Create all tables using the unified Base
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@@ -32,79 +35,158 @@ def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_org_and_user(session_maker):
|
||||
"""Create a test org and user for use in tests."""
|
||||
test_user_id = uuid.uuid4()
|
||||
test_org_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
# Create role first
|
||||
role = Role(name='test-role', rank=1)
|
||||
session.add(role)
|
||||
session.flush()
|
||||
|
||||
# Create org
|
||||
org = Org(id=test_org_id, name='test-org', contact_email='testy@tester.com')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
# Create user with current_org_id
|
||||
user = User(id=test_user_id, current_org_id=test_org_id, role_id=role.id)
|
||||
session.add(user)
|
||||
session.flush()
|
||||
|
||||
# Create org member relationship
|
||||
org_member = OrgMember(
|
||||
org_id=test_org_id,
|
||||
user_id=test_user_id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
|
||||
return test_user_id, test_org_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_customer_id_by_user_id_checks_db_first(session_maker):
|
||||
async def test_find_customer_id_by_user_id_checks_db_first(
|
||||
session_maker, test_org_and_user
|
||||
):
|
||||
"""Test that find_customer_id_by_user_id checks the database first"""
|
||||
|
||||
test_user_id, test_org_id = test_org_and_user
|
||||
|
||||
# Set up the mock for the database query result
|
||||
with session_maker() as session:
|
||||
# Create stripe customer
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id='test-user-id',
|
||||
keycloak_user_id=str(test_user_id),
|
||||
org_id=test_org_id,
|
||||
stripe_customer_id='cus_test123',
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
with patch('integrations.stripe_service.session_maker', session_maker):
|
||||
# Create a mock org object to return from OrgStore
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = test_org_id
|
||||
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
):
|
||||
# Mock the call_sync_from_async to return the org
|
||||
mock_call_sync.return_value = mock_org
|
||||
|
||||
# Call the function
|
||||
result = await find_customer_id_by_user_id('test-user-id')
|
||||
result = await find_customer_id_by_user_id(str(test_user_id))
|
||||
|
||||
# Verify the result
|
||||
assert result == 'cus_test123'
|
||||
|
||||
# Verify that call_sync_from_async was called with the correct function
|
||||
mock_call_sync.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_customer_id_by_user_id_falls_back_to_stripe(session_maker):
|
||||
async def test_find_customer_id_by_user_id_falls_back_to_stripe(
|
||||
session_maker, test_org_and_user
|
||||
):
|
||||
"""Test that find_customer_id_by_user_id falls back to Stripe if not found in the database"""
|
||||
|
||||
test_user_id, test_org_id = test_org_and_user
|
||||
|
||||
# Set up the mock for stripe.Customer.search_async
|
||||
mock_customer = stripe.Customer(id='cus_test123')
|
||||
mock_search = AsyncMock(return_value=MagicMock(data=[mock_customer]))
|
||||
|
||||
# Create a mock org object to return from OrgStore
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = test_org_id
|
||||
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('stripe.Customer.search_async', mock_search),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
):
|
||||
# Mock the call_sync_from_async to return the org
|
||||
mock_call_sync.return_value = mock_org
|
||||
|
||||
# Call the function
|
||||
result = await find_customer_id_by_user_id('test-user-id')
|
||||
result = await find_customer_id_by_user_id(str(test_user_id))
|
||||
|
||||
# Verify the result
|
||||
assert result == 'cus_test123'
|
||||
|
||||
# Verify that Stripe was searched
|
||||
# Verify that Stripe was searched with the org_id
|
||||
mock_search.assert_called_once()
|
||||
assert "metadata['user_id']:'test-user-id'" in mock_search.call_args[1]['query']
|
||||
assert (
|
||||
f"metadata['org_id']:'{str(test_org_id)}'" in mock_search.call_args[1]['query']
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_customer_stores_id_in_db(session_maker):
|
||||
async def test_create_customer_stores_id_in_db(session_maker, test_org_and_user):
|
||||
"""Test that create_customer stores the customer ID in the database"""
|
||||
|
||||
# Set up the mock for stripe.Customer.search_async
|
||||
test_user_id, test_org_id = test_org_and_user
|
||||
|
||||
# Set up the mock for stripe.Customer.search_async and create_async
|
||||
mock_search = AsyncMock(return_value=MagicMock(data=[]))
|
||||
mock_create_async = AsyncMock(return_value=stripe.Customer(id='cus_test123'))
|
||||
|
||||
# Create a mock org object to return from OrgStore
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = test_org_id
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('stripe.Customer.search_async', mock_search),
|
||||
patch('stripe.Customer.create_async', mock_create_async),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
|
||||
):
|
||||
# Mock the call_sync_from_async to return the org
|
||||
mock_call_sync.return_value = mock_org
|
||||
|
||||
# Call the function
|
||||
result = await find_or_create_customer('test-user-id')
|
||||
result = await find_or_create_customer_by_user_id(str(test_user_id))
|
||||
|
||||
# Verify the result
|
||||
assert result == 'cus_test123'
|
||||
assert result == {'customer_id': 'cus_test123', 'org_id': str(test_org_id)}
|
||||
|
||||
# Verify that the stripe customer was stored in the db
|
||||
with session_maker() as session:
|
||||
customer = session.query(StripeCustomer).first()
|
||||
assert customer.id > 0
|
||||
assert customer.keycloak_user_id == 'test-user-id'
|
||||
assert customer.keycloak_user_id == str(test_user_id)
|
||||
assert customer.org_id == test_org_id
|
||||
assert customer.stripe_customer_id == 'cus_test123'
|
||||
assert customer.created_at is not None
|
||||
assert customer.updated_at is not None
|
||||
|
||||
164
enterprise/tests/unit/test_user_store.py
Normal file
164
enterprise/tests/unit/test_user_store.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
# Mock the database module before importing UserStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from sqlalchemy.orm import configure_mappers
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope='session')
|
||||
def load_all_models():
|
||||
configure_mappers() # fail fast if anything’s missing
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_api():
|
||||
api_key_patch = patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test_key')
|
||||
api_url_patch = patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'
|
||||
)
|
||||
team_id_patch = patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test_team')
|
||||
client_patch = patch('httpx.AsyncClient')
|
||||
|
||||
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.json = MagicMock(return_value={'key': 'test_api_key'})
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_response
|
||||
)
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stripe():
|
||||
search_patch = patch(
|
||||
'stripe.Customer.search_async',
|
||||
AsyncMock(return_value=MagicMock(id='mock-customer-id')),
|
||||
)
|
||||
payment_patch = patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[{}])),
|
||||
)
|
||||
with search_patch, payment_patch:
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_no_org_id():
|
||||
# Test UserStore.create_default_settings with empty org_id
|
||||
settings = await UserStore.create_default_settings('', 'test-user-id')
|
||||
assert settings is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_require_org(session_maker, mock_stripe):
|
||||
# Mock stripe_service.has_payment_method to return False
|
||||
with (
|
||||
patch(
|
||||
'stripe.Customer.list_payment_methods_async',
|
||||
AsyncMock(return_value=MagicMock(data=[])),
|
||||
),
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
):
|
||||
settings = await UserStore.create_default_settings(
|
||||
'test-org-id', 'test-user-id'
|
||||
)
|
||||
assert settings is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_with_litellm(session_maker, mock_litellm_api):
|
||||
# Test that UserStore.create_default_settings works with LiteLLM
|
||||
with (
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch('storage.user_store.session_maker', session_maker),
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'attributes': {'github_id': ['12345']}}),
|
||||
),
|
||||
):
|
||||
settings = await UserStore.create_default_settings(
|
||||
'test-org-id', 'test-user-id'
|
||||
)
|
||||
assert settings is not None
|
||||
assert settings.llm_api_key.get_secret_value() == 'test_api_key'
|
||||
assert settings.llm_base_url == 'http://test.url'
|
||||
assert settings.agent == 'CodeActAgent'
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Complex integration test with session isolation issues')
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user(session_maker, mock_litellm_api):
|
||||
# Test creating a new user - skipped due to complex session isolation issues
|
||||
pass
|
||||
|
||||
|
||||
def test_get_user_by_id(session_maker):
|
||||
# Test getting user by ID
|
||||
test_org_id = uuid.uuid4()
|
||||
test_user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
with session_maker() as session:
|
||||
# Create a test user
|
||||
user = User(id=uuid.UUID(test_user_id), current_org_id=test_org_id)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
user_id = user.id
|
||||
|
||||
# Test retrieval
|
||||
with patch('storage.user_store.session_maker', session_maker):
|
||||
retrieved_user = UserStore.get_user_by_id(test_user_id)
|
||||
assert retrieved_user is not None
|
||||
assert retrieved_user.id == user_id
|
||||
|
||||
|
||||
def test_list_users(session_maker):
|
||||
# Test listing all users
|
||||
test_org_id1 = uuid.uuid4()
|
||||
test_org_id2 = uuid.uuid4()
|
||||
test_user_id1 = uuid.uuid4()
|
||||
test_user_id2 = uuid.uuid4()
|
||||
with session_maker() as session:
|
||||
# Create test users
|
||||
user1 = User(id=test_user_id1, current_org_id=test_org_id1)
|
||||
user2 = User(id=test_user_id2, current_org_id=test_org_id2)
|
||||
session.add_all([user1, user2])
|
||||
session.commit()
|
||||
|
||||
# Test listing
|
||||
with patch('storage.user_store.session_maker', session_maker):
|
||||
users = UserStore.list_users()
|
||||
assert len(users) >= 2
|
||||
user_ids = [user.id for user in users]
|
||||
assert test_user_id1 in user_ids
|
||||
assert test_user_id2 in user_ids
|
||||
|
||||
|
||||
def test_get_kwargs_from_settings():
|
||||
# Test extracting user kwargs from settings
|
||||
settings = Settings(
|
||||
language='es',
|
||||
enable_sound_notifications=True,
|
||||
llm_api_key=SecretStr('test-key'),
|
||||
)
|
||||
|
||||
kwargs = UserStore.get_kwargs_from_settings(settings)
|
||||
|
||||
# Should only include fields that exist in User model
|
||||
assert 'language' in kwargs
|
||||
assert 'enable_sound_notifications' in kwargs
|
||||
# Should not include fields that don't exist in User model
|
||||
assert 'llm_api_key' not in kwargs
|
||||
65
evaluation/benchmarks/swefficiency/README.md
Normal file
65
evaluation/benchmarks/swefficiency/README.md
Normal file
@@ -0,0 +1,65 @@
|
||||
# SWE-fficiency Evaluation
|
||||
|
||||
This folder contains the OpenHands inference generation of the [SWE-fficiency benchmark](https://swefficiency.com/) ([paper](https://arxiv.org/pdf/2507.12415v1)).
|
||||
|
||||
The evaluation consists of three steps:
|
||||
|
||||
1. Environment setup: [install python environment](../../README.md#development-environment) and [configure LLM config](../../README.md#configure-openhands-and-your-llm).
|
||||
2. [Run inference](#running-inference-locally-with-docker): Generate a edit patch for each Github issue
|
||||
3. [Evaluate patches](#evaluate-generated-patches)
|
||||
|
||||
## Setup Environment and LLM Configuration
|
||||
|
||||
Please follow instruction [here](../../README.md#setup) to setup your local development environment and LLM.
|
||||
|
||||
## Running inference Locally with Docker
|
||||
|
||||
Make sure your Docker daemon is running, and you have ample disk space (at least 200-500GB, depends on the SWE-PErf set you are running on) for the instance-level docker image.
|
||||
|
||||
When the `run_infer.sh` script is started, it will automatically pull the relevant SWE-Perf images.
|
||||
For example, for instance ID `scikit-learn_scikit-learn-11674`, it will try to pull our pre-build docker image `betty1202/sweb.eval.x86_64.scikit-learn_s_scikit-learn-11674` from DockerHub.
|
||||
This image will be used create an OpenHands runtime image where the agent will operate on.
|
||||
|
||||
```bash
|
||||
./evaluation/benchmarks/swefficiency/scripts/run_infer.sh [model_config] [git-version] [agent] [eval_limit] [max_iter] [num_workers] [dataset] [dataset_split] [n_runs] [mode]
|
||||
|
||||
# Example
|
||||
./evaluation/benchmarks/swefficiency/scripts/run_infer.sh llm.eval_gpt4_1106_preview HEAD CodeActAgent 500 100 1 swefficiency/swefficiency test
|
||||
```
|
||||
|
||||
where `model_config` is mandatory, and the rest are optional.
|
||||
|
||||
- `model_config`, e.g. `eval_gpt4_1106_preview`, is the config group name for your
|
||||
LLM settings, as defined in your `config.toml`.
|
||||
- `git-version`, e.g. `HEAD`, is the git commit hash of the OpenHands version you would
|
||||
like to evaluate. It could also be a release tag like `0.6.2`.
|
||||
- `agent`, e.g. `CodeActAgent`, is the name of the agent for benchmarks, defaulting
|
||||
to `CodeActAgent`.
|
||||
- `eval_limit`, e.g. `10`, limits the evaluation to the first `eval_limit` instances. By
|
||||
default, the script evaluates the entire SWE-Perf test set (140 issues). Note:
|
||||
in order to use `eval_limit`, you must also set `agent`.
|
||||
- `max_iter`, e.g. `20`, is the maximum number of iterations for the agent to run. By
|
||||
default, it is set to 100.
|
||||
- `num_workers`, e.g. `3`, is the number of parallel workers to run the evaluation. By
|
||||
default, it is set to 1.
|
||||
- `dataset`, a huggingface dataset name. e.g. `SWE-Perf/SWE-Perf`, specifies which dataset to evaluate on.
|
||||
- `dataset_split`, split for the huggingface dataset. e.g., `test`, `dev`. Default to `test`.
|
||||
|
||||
- `n_runs`, e.g. `3`, is the number of times to run the evaluation. Default is 1.
|
||||
- `mode`, e.g. `swt`, `swt-ci`, or `swe`, specifies the evaluation mode. Default is `swe`.
|
||||
|
||||
> [!CAUTION]
|
||||
> Setting `num_workers` larger than 1 is not officially tested, YMMV.
|
||||
|
||||
|
||||
Let's say you'd like to run 10 instances using `llm.eval_gpt4_1106_preview` and CodeActAgent,
|
||||
|
||||
then your command would be:
|
||||
|
||||
```bash
|
||||
./evaluation/benchmarks/swe_bench/scripts/run_infer.sh llm.eval_gpt4_1106_preview HEAD CodeActAgent 10
|
||||
```
|
||||
|
||||
### 2. Run the SWE-fficiency benchmark official evaluation
|
||||
|
||||
Once the output is converted, use the [official SWE-fficiency benchmark evaluation](https://github.com/swefficiency/swefficiency) to evaluate it.
|
||||
0
evaluation/benchmarks/swefficiency/__init__.py
Normal file
0
evaluation/benchmarks/swefficiency/__init__.py
Normal file
52
evaluation/benchmarks/swefficiency/binary_patch_utils.py
Normal file
52
evaluation/benchmarks/swefficiency/binary_patch_utils.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Utilities for handling binary files and patch generation in SWE-bench evaluation.
|
||||
"""
|
||||
|
||||
|
||||
def remove_binary_diffs(patch_text):
|
||||
"""
|
||||
Remove binary file diffs from a git patch.
|
||||
|
||||
Args:
|
||||
patch_text (str): The git patch text
|
||||
|
||||
Returns:
|
||||
str: The cleaned patch text with binary diffs removed
|
||||
"""
|
||||
lines = patch_text.splitlines()
|
||||
cleaned_lines = []
|
||||
block = []
|
||||
is_binary_block = False
|
||||
|
||||
for line in lines:
|
||||
if line.startswith('diff --git '):
|
||||
if block and not is_binary_block:
|
||||
cleaned_lines.extend(block)
|
||||
block = [line]
|
||||
is_binary_block = False
|
||||
elif 'Binary files' in line:
|
||||
is_binary_block = True
|
||||
block.append(line)
|
||||
else:
|
||||
block.append(line)
|
||||
|
||||
if block and not is_binary_block:
|
||||
cleaned_lines.extend(block)
|
||||
return '\n'.join(cleaned_lines)
|
||||
|
||||
|
||||
def remove_binary_files_from_git():
|
||||
"""
|
||||
Generate a bash command to remove binary files from git staging.
|
||||
|
||||
Returns:
|
||||
str: A bash command that removes binary files from git staging
|
||||
"""
|
||||
return """
|
||||
for file in $(git status --porcelain | grep -E "^(M| M|\\?\\?|A| A)" | cut -c4-); do
|
||||
if [ -f "$file" ] && (file "$file" | grep -q "executable" || git check-attr binary "$file" | grep -q "binary: set"); then
|
||||
git rm -f "$file" 2>/dev/null || rm -f "$file"
|
||||
echo "Removed: $file"
|
||||
fi
|
||||
done
|
||||
""".strip()
|
||||
960
evaluation/benchmarks/swefficiency/run_infer.py
Normal file
960
evaluation/benchmarks/swefficiency/run_infer.py
Normal file
@@ -0,0 +1,960 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import functools
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Literal
|
||||
|
||||
import pandas as pd
|
||||
import toml
|
||||
from datasets import load_dataset
|
||||
|
||||
import openhands.agenthub
|
||||
from evaluation.benchmarks.swe_bench.binary_patch_utils import (
|
||||
remove_binary_diffs,
|
||||
remove_binary_files_from_git,
|
||||
)
|
||||
from evaluation.utils.shared import (
|
||||
EvalException,
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
assert_and_raise,
|
||||
codeact_user_response,
|
||||
get_default_sandbox_config_for_eval,
|
||||
get_metrics,
|
||||
is_fatal_evaluation_error,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
run_evaluation,
|
||||
update_llm_config_for_completions_logging,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import (
|
||||
AgentConfig,
|
||||
OpenHandsConfig,
|
||||
get_evaluation_parser,
|
||||
get_llm_config_arg,
|
||||
)
|
||||
from openhands.core.config.condenser_config import NoOpCondenserConfig
|
||||
from openhands.core.config.utils import get_condenser_config_arg
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.critic import AgentFinishedCritic
|
||||
from openhands.events.action import CmdRunAction, FileReadAction, MessageAction
|
||||
from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
ErrorObservation,
|
||||
FileReadObservation,
|
||||
)
|
||||
from openhands.events.serialization.event import event_from_dict, event_to_dict
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
from openhands.utils.shutdown_listener import sleep_if_should_continue
|
||||
|
||||
USE_HINT_TEXT = os.environ.get('USE_HINT_TEXT', 'false').lower() == 'true'
|
||||
RUN_WITH_BROWSING = os.environ.get('RUN_WITH_BROWSING', 'false').lower() == 'true'
|
||||
BenchMode = Literal['swe', 'swt', 'swt-ci']
|
||||
|
||||
|
||||
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
|
||||
'CodeActAgent': codeact_user_response,
|
||||
}
|
||||
|
||||
|
||||
def _get_swebench_workspace_dir_name(instance: pd.Series) -> str:
|
||||
return f'{instance.repo}__{instance.version}'.replace('/', '__')
|
||||
|
||||
|
||||
def get_instruction(instance: pd.Series, metadata: EvalMetadata) -> MessageAction:
|
||||
workspace_dir_name = _get_swebench_workspace_dir_name(instance)
|
||||
|
||||
# TODO: Change to testbed?
|
||||
instruction = f"""
|
||||
<uploaded_files>
|
||||
/workspace/{workspace_dir_name}
|
||||
</uploaded_files>
|
||||
|
||||
I’ve uploaded a python code repository in the directory workspace_dir_name. Consider the following performance workload and `workload()` function showing an specific usage of the repository:
|
||||
<performance_workload>
|
||||
{instance.workload}
|
||||
</performance_workload>
|
||||
|
||||
Can you help me implement the necessary changes to the repository so that the runtime of the `workload()` function is faster? Basic guidelines:
|
||||
1. Your task is to make changes to non-test files in the /workspace directory to improve the performance of the code running in `workload()`. Please do not directly change the implementation of the `workload()` function to optimize things: I want you to focus on making the workload AS IS run faster by only editing the repository containing code that the `workload()` function calls.
|
||||
2. Make changes while ensuring the repository is functionally equivalent to the original: your changes should not introduce new bugs or cause already-passing tests to begin failing after your changes. However, you do not need to worry about tests that already fail without any changes made. For relevant test files you find in the repository, you can run them via the bash command `{instance.test_cmd} <test_file>` to check for correctness. Note that running all the tests may take a long time, so you need to determine which tests are relevant to your changes.
|
||||
3. Make sure the `workload()` function improves in performance after you make changes to the repository. The workload can potentially take some time to run, so please allow it to finish and be generous with setting your timeout parameter (a timeout value of 3600 or larger here is encouraged): for faster iteration, you should adjust the workload script to use fewer iterations. Before you complete your task, please make sure to check that the **original performance workload** and `workload()` function runs successfully and the performance is improved.
|
||||
4. You may need to reinstall/rebuild the repo for your changes to take effect before testing if you made non-Python changes. Reinstalling may take a long time to run (a timeout value of 3600 or larger here is encouraged), so please be patient with running it and allow it to complete if possible. You can reinstall the repository by running the bash command `{instance.rebuild_cmd}` in the workspace directory.
|
||||
5. All the dependencies required to run the `workload()` function are already installed in the environment. You should not install or upgrade any dependencies.
|
||||
|
||||
Follow these steps to improve performance:
|
||||
1. As a first step, explore the repository structure.
|
||||
2. Create a Python script to reproduce the performance workload, execute it with python <workload_file>, and examine the printed output metrics.
|
||||
3. Edit the source code of the repository to improve performance. Please do not change the contents of the `workload()` function itself, but focus on optimizing the code in the repository that the original `workload()` function uses.
|
||||
4. If non-Python changes were made, rebuild the repo to make sure the changes take effect.
|
||||
5. Rerun your script to confirm that performance has improved.
|
||||
6. If necessary, identify any relevant test files in the repository related to your changes and verify that test statuses did not change after your modifications.
|
||||
7. After each attempted change, please reflect on the changes attempted and the performance impact observed. If the performance did not improve, consider alternative approaches or optimizations.
|
||||
8. Once you are satisfied, please use the finish command to complete your task.
|
||||
|
||||
Please remember that you should not change the implementation of the `workload()` function. The performance improvement should solely come from editing the source files in the code repository.
|
||||
"""
|
||||
|
||||
if RUN_WITH_BROWSING:
|
||||
instruction += (
|
||||
'<IMPORTANT!>\nYou SHOULD NEVER attempt to browse the web. </IMPORTANT!>\n'
|
||||
)
|
||||
|
||||
return MessageAction(content=instruction)
|
||||
|
||||
|
||||
def get_instance_docker_image(
|
||||
instance_id: str,
|
||||
) -> str:
|
||||
return f'ghcr.io/swefficiency/swefficiency-images:{instance_id}'
|
||||
|
||||
|
||||
def get_config(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
cpu_group: list[int] | None = None,
|
||||
) -> OpenHandsConfig:
|
||||
# We use a different instance image for the each instance of swe-bench eval
|
||||
base_container_image = get_instance_docker_image(
|
||||
instance['instance_id'],
|
||||
)
|
||||
logger.info(
|
||||
f'Using instance container image: {base_container_image}. '
|
||||
f'Please make sure this image exists. '
|
||||
f'Submit an issue on https://github.com/All-Hands-AI/OpenHands if you run into any issues.'
|
||||
)
|
||||
|
||||
sandbox_config = get_default_sandbox_config_for_eval()
|
||||
sandbox_config.base_container_image = base_container_image
|
||||
sandbox_config.enable_auto_lint = True
|
||||
sandbox_config.use_host_network = False
|
||||
sandbox_config.timeout = 3600
|
||||
|
||||
# Control container cleanup behavior via environment variable
|
||||
# Default to False for multiprocessing stability to prevent cascade failures
|
||||
sandbox_config.rm_all_containers = True
|
||||
|
||||
sandbox_config.platform = 'linux/amd64'
|
||||
sandbox_config.remote_runtime_resource_factor = 4.0
|
||||
sandbox_config.runtime_startup_env_vars.update(
|
||||
{
|
||||
'NO_CHANGE_TIMEOUT_SECONDS': '900', # 15 minutes
|
||||
}
|
||||
)
|
||||
|
||||
if cpu_group is not None:
|
||||
print(f'Configuring Docker runtime with CPU group: {cpu_group}')
|
||||
sandbox_config.docker_runtime_kwargs = {
|
||||
# HACK: Use the cpu_group if provided, otherwise use all available CPUs
|
||||
'cpuset_cpus': ','.join(map(str, cpu_group)),
|
||||
'nano_cpus': int(1e9 * len(cpu_group)), # optional: hard cap to vCPU count
|
||||
'mem_limit': '16g',
|
||||
}
|
||||
|
||||
# Note: We keep rm_all_containers = False for worker process safety
|
||||
|
||||
config = OpenHandsConfig(
|
||||
default_agent=metadata.agent_class,
|
||||
run_as_openhands=False,
|
||||
max_iterations=metadata.max_iterations,
|
||||
runtime=os.environ.get('RUNTIME', 'docker'),
|
||||
sandbox=sandbox_config,
|
||||
# do not mount workspace
|
||||
workspace_base=None,
|
||||
workspace_mount_path=None,
|
||||
)
|
||||
config.set_llm_config(
|
||||
update_llm_config_for_completions_logging(
|
||||
metadata.llm_config, metadata.eval_output_dir, instance['instance_id']
|
||||
)
|
||||
)
|
||||
agent_config = AgentConfig(
|
||||
enable_jupyter=False,
|
||||
enable_browsing=RUN_WITH_BROWSING,
|
||||
enable_llm_editor=False,
|
||||
enable_mcp=False,
|
||||
condenser=metadata.condenser_config,
|
||||
enable_prompt_extensions=False,
|
||||
)
|
||||
config.set_agent_config(agent_config)
|
||||
return config
|
||||
|
||||
|
||||
def initialize_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required
|
||||
metadata: EvalMetadata,
|
||||
):
|
||||
"""Initialize the runtime for the agent.
|
||||
|
||||
This function is called before the runtime is used to run the agent.
|
||||
"""
|
||||
logger.info('-' * 30)
|
||||
logger.info('BEGIN Runtime Initialization Fn')
|
||||
logger.info('-' * 30)
|
||||
workspace_dir_name = _get_swebench_workspace_dir_name(instance)
|
||||
obs: CmdOutputObservation
|
||||
|
||||
# Set instance id and git configuration
|
||||
action = CmdRunAction(
|
||||
command=f"""echo 'export SWE_INSTANCE_ID={instance['instance_id']}' >> ~/.bashrc && echo 'export PIP_CACHE_DIR=~/.cache/pip' >> ~/.bashrc && echo "alias git='git --no-pager'" >> ~/.bashrc && git config --global core.pager "" && git config --global diff.binary false"""
|
||||
)
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to export SWE_INSTANCE_ID and configure git: {str(obs)}',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """)
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to export USER: {str(obs)}')
|
||||
|
||||
# inject the init script
|
||||
script_dir = os.path.dirname(__file__)
|
||||
|
||||
# inject the instance info
|
||||
action = CmdRunAction(command='mkdir -p /swe_util/eval_data/instances')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to create /swe_util/eval_data/instances: {str(obs)}',
|
||||
)
|
||||
|
||||
swe_instance_json_name = 'swe-bench-instance.json'
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Construct the full path for the desired file name within the temporary directory
|
||||
temp_file_path = os.path.join(temp_dir, swe_instance_json_name)
|
||||
# Write to the file with the desired name within the temporary directory
|
||||
with open(temp_file_path, 'w') as f:
|
||||
if not isinstance(instance, dict):
|
||||
json.dump([instance.to_dict()], f)
|
||||
else:
|
||||
json.dump([instance], f)
|
||||
|
||||
# Copy the file to the desired location
|
||||
runtime.copy_to(temp_file_path, '/swe_util/eval_data/instances/')
|
||||
|
||||
# inject the instance swe entry
|
||||
runtime.copy_to(
|
||||
str(os.path.join(script_dir, 'scripts/setup/instance_swe_entry.sh')),
|
||||
'/swe_util/',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='cat ~/.bashrc')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to cat ~/.bashrc: {str(obs)}')
|
||||
|
||||
action = CmdRunAction(command='source ~/.bashrc')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
if isinstance(obs, ErrorObservation):
|
||||
logger.error(f'Failed to source ~/.bashrc: {str(obs)}')
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to source ~/.bashrc: {str(obs)}')
|
||||
|
||||
action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to source /swe_util/instance_swe_entry.sh: {str(obs)}',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to cd to /workspace/{workspace_dir_name}: {str(obs)}',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='git reset --hard')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to git reset --hard: {str(obs)}')
|
||||
|
||||
action = CmdRunAction(
|
||||
command='for remote_name in $(git remote); do git remote remove "${remote_name}"; done'
|
||||
)
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to remove git remotes: {str(obs)}')
|
||||
|
||||
action = CmdRunAction(command='which python')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0 and 'testbed' in obs.content,
|
||||
f'Expected to find python interpreter from testbed, but got: {str(obs)}',
|
||||
)
|
||||
|
||||
logger.info('-' * 30)
|
||||
logger.info('END Runtime Initialization Fn')
|
||||
logger.info('-' * 30)
|
||||
|
||||
|
||||
def complete_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series, # this argument is not required, but it is used to get the workspace_dir_name
|
||||
) -> dict[str, Any]:
|
||||
"""Complete the runtime for the agent.
|
||||
|
||||
This function is called before the runtime is used to run the agent.
|
||||
If you need to do something in the sandbox to get the correctness metric after
|
||||
the agent has run, modify this function.
|
||||
"""
|
||||
logger.info('-' * 30)
|
||||
logger.info('BEGIN Runtime Completion Fn')
|
||||
logger.info('-' * 30)
|
||||
obs: CmdOutputObservation
|
||||
workspace_dir_name = _get_swebench_workspace_dir_name(instance)
|
||||
|
||||
action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
if obs.exit_code == -1:
|
||||
# The previous command is still running
|
||||
# We need to kill previous command
|
||||
logger.info('The previous command is still running, trying to kill it...')
|
||||
action = CmdRunAction(command='C-c')
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
# Then run the command again
|
||||
action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
if obs.exit_code == -1:
|
||||
# The previous command is still running
|
||||
# We need to kill previous command
|
||||
logger.info('The previous command is still running, trying to ctrl+z it...')
|
||||
action = CmdRunAction(command='C-z')
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
# Then run the command again
|
||||
action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
assert_and_raise(
|
||||
isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
|
||||
f'Failed to cd to /workspace/{workspace_dir_name}: {str(obs)}',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='git config --global core.pager ""')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
|
||||
f'Failed to git config --global core.pager "": {str(obs)}',
|
||||
)
|
||||
|
||||
# First check for any git repositories in subdirectories
|
||||
action = CmdRunAction(command='find . -type d -name .git -not -path "./.git"')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
|
||||
f'Failed to find git repositories: {str(obs)}',
|
||||
)
|
||||
|
||||
git_dirs = [p for p in obs.content.strip().split('\n') if p]
|
||||
if git_dirs:
|
||||
# Remove all .git directories in subdirectories
|
||||
for git_dir in git_dirs:
|
||||
action = CmdRunAction(command=f'rm -rf "{git_dir}"')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
|
||||
f'Failed to remove git directory {git_dir}: {str(obs)}',
|
||||
)
|
||||
|
||||
# add all files
|
||||
action = CmdRunAction(command='git add -A')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
|
||||
f'Failed to git add -A: {str(obs)}',
|
||||
)
|
||||
|
||||
# Remove binary files from git staging
|
||||
action = CmdRunAction(command=remove_binary_files_from_git())
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
|
||||
f'Failed to remove binary files: {str(obs)}',
|
||||
)
|
||||
|
||||
n_retries = 0
|
||||
git_patch = None
|
||||
while n_retries < 5:
|
||||
action = CmdRunAction(
|
||||
command=f'git diff --no-color --cached {instance["base_commit"]} > patch.diff'
|
||||
)
|
||||
action.set_hard_timeout(max(300 + 100 * n_retries, 600))
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
n_retries += 1
|
||||
if isinstance(obs, CmdOutputObservation):
|
||||
if obs.exit_code == 0:
|
||||
# Read the patch file
|
||||
action = FileReadAction(path='patch.diff')
|
||||
action.set_hard_timeout(max(300 + 100 * n_retries, 600))
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
if isinstance(obs, FileReadObservation):
|
||||
git_patch = obs.content
|
||||
break
|
||||
elif isinstance(obs, ErrorObservation):
|
||||
# Fall back to cat "patch.diff" to get the patch
|
||||
assert 'File could not be decoded as utf-8' in obs.content
|
||||
action = CmdRunAction(command='cat patch.diff')
|
||||
action.set_hard_timeout(max(300 + 100 * n_retries, 600))
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
assert isinstance(obs, CmdOutputObservation) and obs.exit_code == 0
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
git_patch = obs.content
|
||||
break
|
||||
else:
|
||||
assert_and_raise(False, f'Unexpected observation type: {str(obs)}')
|
||||
else:
|
||||
logger.info('Failed to get git diff, retrying...')
|
||||
sleep_if_should_continue(10)
|
||||
elif isinstance(obs, ErrorObservation):
|
||||
logger.error(f'Error occurred: {obs.content}. Retrying...')
|
||||
sleep_if_should_continue(10)
|
||||
else:
|
||||
assert_and_raise(False, f'Unexpected observation type: {str(obs)}')
|
||||
|
||||
assert_and_raise(git_patch is not None, 'Failed to get git diff (None)')
|
||||
|
||||
# Remove binary diffs from the patch
|
||||
git_patch = remove_binary_diffs(git_patch)
|
||||
|
||||
logger.info('-' * 30)
|
||||
logger.info('END Runtime Completion Fn')
|
||||
logger.info('-' * 30)
|
||||
return {'git_patch': git_patch}
|
||||
|
||||
|
||||
class CPUGroupManager:
|
||||
def __init__(self, cpu_groups_queue: multiprocessing.Queue):
|
||||
self.cpu_groups_queue = cpu_groups_queue
|
||||
|
||||
def __enter__(self):
|
||||
# Get the current CPU group for this worker]
|
||||
if self.cpu_groups_queue is not None:
|
||||
self.cpu_group = self.cpu_groups_queue.get()
|
||||
logger.info(f'Worker started with CPU group: {self.cpu_group}')
|
||||
return self.cpu_group
|
||||
return None
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
# Put the CPU group back into the queue for other workers to use
|
||||
if self.cpu_groups_queue is not None:
|
||||
self.cpu_groups_queue.put(self.cpu_group)
|
||||
logger.info(f'Worker finished with CPU group: {self.cpu_group}')
|
||||
|
||||
|
||||
def cleanup_docker_resources_for_worker():
|
||||
"""Clean up Docker resources specific to this worker process.
|
||||
|
||||
This prevents cascade failures when one worker's container crashes.
|
||||
Note: This only cleans up stale locks, not containers, to avoid
|
||||
interfering with other workers. Container cleanup is handled
|
||||
by the DockerRuntime.close() method based on configuration.
|
||||
"""
|
||||
|
||||
# Clean up any stale port locks from crashed processes
|
||||
try:
|
||||
from openhands.runtime.utils.port_lock import cleanup_stale_locks
|
||||
|
||||
cleanup_stale_locks(max_age_seconds=300) # Clean up locks older than 5 minutes
|
||||
except Exception as e:
|
||||
logger.debug(f'Error cleaning up stale port locks: {e}')
|
||||
|
||||
|
||||
def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
runtime_failure_count: int = 0,
|
||||
cpu_groups_queue: multiprocessing.Queue = None,
|
||||
) -> EvalOutput:
|
||||
# Clean up any Docker resources from previous failed runs
|
||||
cleanup_docker_resources_for_worker()
|
||||
|
||||
# HACK: Use the global and get the cpu group for this worker.
|
||||
with CPUGroupManager(cpu_groups_queue) as cpu_group:
|
||||
config = get_config(instance, metadata, cpu_group=cpu_group)
|
||||
|
||||
# Setup the logger properly, so you can run multi-processing to parallelize the evaluation
|
||||
if reset_logger:
|
||||
log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
|
||||
reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
|
||||
else:
|
||||
logger.info(f'Starting evaluation for instance {instance.instance_id}.')
|
||||
|
||||
metadata = copy.deepcopy(metadata)
|
||||
metadata.details['runtime_failure_count'] = runtime_failure_count
|
||||
metadata.details['remote_runtime_resource_factor'] = (
|
||||
config.sandbox.remote_runtime_resource_factor
|
||||
)
|
||||
|
||||
runtime = create_runtime(config, sid=None)
|
||||
call_async_from_sync(runtime.connect)
|
||||
|
||||
try:
|
||||
initialize_runtime(runtime, instance, metadata)
|
||||
|
||||
message_action = get_instruction(instance, metadata)
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
initial_user_action=message_action,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
# if fatal error, throw EvalError to trigger re-run
|
||||
if is_fatal_evaluation_error(state.last_error):
|
||||
raise EvalException('Fatal error detected: ' + state.last_error)
|
||||
|
||||
# ======= THIS IS SWE-Bench specific =======
|
||||
# Get git patch
|
||||
return_val = complete_runtime(runtime, instance)
|
||||
git_patch = return_val['git_patch']
|
||||
logger.info(
|
||||
f'Got git diff for instance {instance.instance_id}:\n--------\n{git_patch}\n--------'
|
||||
)
|
||||
except Exception as e:
|
||||
# Log the error but don't let it crash other workers
|
||||
logger.error(
|
||||
f'Error in worker processing instance {instance.instance_id}: {str(e)}'
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
# Ensure runtime is properly closed to prevent cascade failures
|
||||
try:
|
||||
runtime.close()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'Error closing runtime for {instance.instance_id}: {str(e)}'
|
||||
)
|
||||
# Don't re-raise - we want to continue cleanup
|
||||
|
||||
# ==========================================
|
||||
|
||||
# ======= Attempt to evaluate the agent's edits =======
|
||||
# we use eval_infer.sh to evaluate the agent's edits, not here
|
||||
# because the agent may alter the environment / testcases
|
||||
test_result = {
|
||||
'git_patch': git_patch,
|
||||
}
|
||||
|
||||
# If you are working on some simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
|
||||
# You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
|
||||
if state is None:
|
||||
raise ValueError('State should not be None.')
|
||||
|
||||
# NOTE: this is NO LONGER the event stream, but an agent history that includes delegate agent's events
|
||||
histories = [event_to_dict(event) for event in state.history]
|
||||
metrics = get_metrics(state)
|
||||
|
||||
# Save the output
|
||||
instruction = message_action.content
|
||||
if message_action.image_urls:
|
||||
instruction += (
|
||||
'\n\n<image_urls>'
|
||||
+ '\n'.join(message_action.image_urls)
|
||||
+ '</image_urls>'
|
||||
)
|
||||
output = EvalOutput(
|
||||
instance_id=instance.instance_id,
|
||||
instruction=instruction,
|
||||
instance=instance.to_dict(), # SWE Bench specific
|
||||
test_result=test_result,
|
||||
metadata=metadata,
|
||||
history=histories,
|
||||
metrics=metrics,
|
||||
error=state.last_error if state and state.last_error else None,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def filter_dataset(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame:
|
||||
file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.toml')
|
||||
if os.path.exists(file_path):
|
||||
with open(file_path, 'r') as file:
|
||||
data = toml.load(file)
|
||||
if 'selected_ids' in data:
|
||||
selected_ids = data['selected_ids']
|
||||
logger.info(
|
||||
f'Filtering {len(selected_ids)} tasks from "selected_ids"...'
|
||||
)
|
||||
subset = dataset[dataset[filter_column].isin(selected_ids)]
|
||||
logger.info(f'Retained {subset.shape[0]} tasks after filtering')
|
||||
return subset
|
||||
if 'selected_repos' in data:
|
||||
# repos for the swe-bench instances:
|
||||
# ['astropy/astropy', 'django/django', 'matplotlib/matplotlib', 'mwaskom/seaborn', 'pallets/flask', 'psf/requests', 'pydata/xarray', 'pylint-dev/pylint', 'pytest-dev/pytest', 'scikit-learn/scikit-learn', 'sphinx-doc/sphinx', 'sympy/sympy']
|
||||
selected_repos = data['selected_repos']
|
||||
if isinstance(selected_repos, str):
|
||||
selected_repos = [selected_repos]
|
||||
assert isinstance(selected_repos, list)
|
||||
logger.info(
|
||||
f'Filtering {selected_repos} tasks from "selected_repos"...'
|
||||
)
|
||||
subset = dataset[dataset['repo'].isin(selected_repos)]
|
||||
logger.info(f'Retained {subset.shape[0]} tasks after filtering')
|
||||
return subset
|
||||
|
||||
skip_ids = os.environ.get('SKIP_IDS', '').split(',')
|
||||
if len(skip_ids) > 0:
|
||||
logger.info(f'Filtering {len(skip_ids)} tasks from "SKIP_IDS"...')
|
||||
return dataset[~dataset[filter_column].isin(skip_ids)]
|
||||
return dataset
|
||||
|
||||
|
||||
def divide_cpus_among_workers(num_workers, num_cpus_per_worker=4, num_to_skip=0):
|
||||
"""Divide CPUs among workers, with better error handling for multiprocessing."""
|
||||
try:
|
||||
current_cpus = list(os.sched_getaffinity(0))
|
||||
except AttributeError:
|
||||
# os.sched_getaffinity not available on all platforms
|
||||
import multiprocessing
|
||||
|
||||
current_cpus = list(range(multiprocessing.cpu_count()))
|
||||
|
||||
num_cpus = len(current_cpus)
|
||||
if num_workers <= 0:
|
||||
raise ValueError('Number of workers must be greater than 0')
|
||||
|
||||
# Chec that num worers and num_cpus_per_worker fit into available CPUs
|
||||
total_cpus_needed = num_workers * num_cpus_per_worker + num_to_skip
|
||||
if total_cpus_needed > num_cpus:
|
||||
raise ValueError(
|
||||
f'Not enough CPUs available. Requested {total_cpus_needed} '
|
||||
f'CPUs (num_workers={num_workers}, num_cpus_per_worker={num_cpus_per_worker}, '
|
||||
f'num_to_skip={num_to_skip}), but only {num_cpus} CPUs are available.'
|
||||
)
|
||||
|
||||
# Divide this into groups, skipping the first `num_to_skip` CPUs.
|
||||
available_cpus = current_cpus[num_to_skip:]
|
||||
cpu_groups = [
|
||||
available_cpus[i * num_cpus_per_worker : (i + 1) * num_cpus_per_worker]
|
||||
for i in range(num_workers)
|
||||
]
|
||||
print(
|
||||
f'Divided {num_cpus} CPUs into {num_workers} groups, each with {num_cpus_per_worker} CPUs.'
|
||||
)
|
||||
print(f'CPU groups: {cpu_groups}')
|
||||
|
||||
return cpu_groups
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = get_evaluation_parser()
|
||||
parser.add_argument(
|
||||
'--dataset',
|
||||
type=str,
|
||||
default=None,
|
||||
help='data set to evaluate on, for now use local.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--split',
|
||||
type=str,
|
||||
default='test',
|
||||
help='split to evaluate on',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--mode',
|
||||
type=str,
|
||||
default='swe',
|
||||
help='mode to evaluate on',
|
||||
)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
# NOTE: It is preferable to load datasets from huggingface datasets and perform post-processing
|
||||
# so we don't need to manage file uploading to OpenHands's repo
|
||||
|
||||
# dataset = load_dataset(args.dataset, split=args.split)
|
||||
# swe_bench_tests = filter_dataset(dataset.to_pandas(), 'instance_id')
|
||||
dataset = load_dataset(args.dataset, split=args.split)
|
||||
|
||||
# Convert dataset to pandas DataFrame if it is not already.
|
||||
if not isinstance(dataset, pd.DataFrame):
|
||||
dataset = dataset.to_pandas()
|
||||
|
||||
dataset['version'] = dataset['version'].astype(str)
|
||||
|
||||
# Convert created_at column to string.
|
||||
dataset['created_at'] = dataset['created_at'].astype(str)
|
||||
|
||||
swe_bench_tests = filter_dataset(dataset, 'instance_id')
|
||||
|
||||
logger.info(
|
||||
f'Loaded dataset {args.dataset} with split {args.split}: {len(swe_bench_tests)} tasks'
|
||||
)
|
||||
|
||||
llm_config = None
|
||||
if args.llm_config:
|
||||
llm_config = get_llm_config_arg(args.llm_config)
|
||||
llm_config.log_completions = True
|
||||
# modify_params must be False for evaluation purpose, for reproducibility and accurancy of results
|
||||
llm_config.modify_params = False
|
||||
|
||||
if llm_config is None:
|
||||
raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
|
||||
|
||||
# Get condenser config from environment variable
|
||||
condenser_name = os.environ.get('EVAL_CONDENSER')
|
||||
if condenser_name:
|
||||
condenser_config = get_condenser_config_arg(condenser_name)
|
||||
if condenser_config is None:
|
||||
raise ValueError(
|
||||
f'Could not find Condenser config: EVAL_CONDENSER={condenser_name}'
|
||||
)
|
||||
else:
|
||||
# If no specific condenser config is provided via env var, default to NoOpCondenser
|
||||
condenser_config = NoOpCondenserConfig()
|
||||
logger.debug(
|
||||
'No Condenser config provided via EVAL_CONDENSER, using NoOpCondenser.'
|
||||
)
|
||||
|
||||
details = {'mode': args.mode}
|
||||
_agent_cls = openhands.agenthub.Agent.get_cls(args.agent_cls)
|
||||
|
||||
dataset_descrption = (
|
||||
args.dataset.replace('/', '__') + '-' + args.split.replace('/', '__')
|
||||
)
|
||||
metadata = make_metadata(
|
||||
llm_config,
|
||||
dataset_descrption,
|
||||
args.agent_cls,
|
||||
args.max_iterations,
|
||||
args.eval_note,
|
||||
args.eval_output_dir,
|
||||
details=details,
|
||||
condenser_config=condenser_config,
|
||||
)
|
||||
|
||||
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
|
||||
print(f'### OUTPUT FILE: {output_file} ###')
|
||||
|
||||
# Run evaluation in iterative mode:
|
||||
# If a rollout fails to output AgentFinishAction, we will try again until it succeeds OR total 3 attempts have been made.
|
||||
ITERATIVE_EVAL_MODE = (
|
||||
os.environ.get('ITERATIVE_EVAL_MODE', 'false').lower() == 'true'
|
||||
)
|
||||
ITERATIVE_EVAL_MODE_MAX_ATTEMPTS = int(
|
||||
os.environ.get('ITERATIVE_EVAL_MODE_MAX_ATTEMPTS', '3')
|
||||
)
|
||||
|
||||
# Get all CPUs and divide into groups of num_workers and put them into a multiprocessing.Queue.
|
||||
cpu_groups_queue = None
|
||||
cpu_groups_list = divide_cpus_among_workers(args.eval_num_workers, num_to_skip=8)
|
||||
cpu_groups_queue = multiprocessing.Manager().Queue()
|
||||
for cpu_group in cpu_groups_list:
|
||||
cpu_groups_queue.put(cpu_group)
|
||||
|
||||
if not ITERATIVE_EVAL_MODE:
|
||||
# load the dataset
|
||||
instances = prepare_dataset(swe_bench_tests, output_file, args.eval_n_limit)
|
||||
|
||||
process_instance_with_cpu_groups = functools.partial(
|
||||
process_instance,
|
||||
cpu_groups_queue=cpu_groups_queue,
|
||||
)
|
||||
|
||||
config = get_config(
|
||||
instances.iloc[0], # Use the first instance to get the config
|
||||
metadata,
|
||||
cpu_group=None, # We will use the cpu_groups_queue to get the cpu group later
|
||||
)
|
||||
|
||||
run_evaluation(
|
||||
instances,
|
||||
metadata,
|
||||
output_file,
|
||||
args.eval_num_workers,
|
||||
process_instance_with_cpu_groups,
|
||||
timeout_seconds=8
|
||||
* 60
|
||||
* 60, # 8 hour PER instance should be more than enough
|
||||
max_retries=3,
|
||||
)
|
||||
else:
|
||||
critic = AgentFinishedCritic()
|
||||
|
||||
def get_cur_output_file_path(attempt: int) -> str:
|
||||
return (
|
||||
f'{output_file.removesuffix(".jsonl")}.critic_attempt_{attempt}.jsonl'
|
||||
)
|
||||
|
||||
eval_ids = None
|
||||
for attempt in range(1, ITERATIVE_EVAL_MODE_MAX_ATTEMPTS + 1):
|
||||
cur_output_file = get_cur_output_file_path(attempt)
|
||||
logger.info(
|
||||
f'Running evaluation with critic {critic.__class__.__name__} for attempt {attempt} of {ITERATIVE_EVAL_MODE_MAX_ATTEMPTS}.'
|
||||
)
|
||||
|
||||
# For deterministic eval, we set temperature to 0.1 for (>1) attempt
|
||||
# so hopefully we get slightly different results
|
||||
if attempt > 1 and metadata.llm_config.temperature == 0:
|
||||
logger.info(
|
||||
f'Detected temperature is 0 for (>1) attempt {attempt}. Setting temperature to 0.1...'
|
||||
)
|
||||
metadata.llm_config.temperature = 0.1
|
||||
|
||||
# Load instances - at first attempt, we evaluate all instances
|
||||
# On subsequent attempts, we only evaluate the instances that failed the previous attempt determined by critic
|
||||
instances = prepare_dataset(
|
||||
swe_bench_tests, cur_output_file, args.eval_n_limit, eval_ids=eval_ids
|
||||
)
|
||||
if len(instances) > 0 and not isinstance(
|
||||
instances['PASS_TO_PASS'][instances['PASS_TO_PASS'].index[0]], str
|
||||
):
|
||||
for col in ['PASS_TO_PASS', 'FAIL_TO_PASS']:
|
||||
instances[col] = instances[col].apply(lambda x: str(x))
|
||||
|
||||
# Run evaluation - but save them to cur_output_file
|
||||
logger.info(
|
||||
f'Evaluating {len(instances)} instances for attempt {attempt}...'
|
||||
)
|
||||
run_evaluation(
|
||||
instances,
|
||||
metadata,
|
||||
cur_output_file,
|
||||
args.eval_num_workers,
|
||||
process_instance,
|
||||
timeout_seconds=8
|
||||
* 60
|
||||
* 60, # 8 hour PER instance should be more than enough
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
# When eval is done, we update eval_ids to the instances that failed the current attempt
|
||||
instances_failed = []
|
||||
logger.info(
|
||||
f'Use critic {critic.__class__.__name__} to check {len(instances)} instances for attempt {attempt}...'
|
||||
)
|
||||
with open(cur_output_file, 'r') as f:
|
||||
for line in f:
|
||||
instance = json.loads(line)
|
||||
try:
|
||||
history = [
|
||||
event_from_dict(event) for event in instance['history']
|
||||
]
|
||||
critic_result = critic.evaluate(
|
||||
history, instance['test_result'].get('git_patch', '')
|
||||
)
|
||||
if not critic_result.success:
|
||||
instances_failed.append(instance['instance_id'])
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'Error loading history for instance {instance["instance_id"]}: {e}'
|
||||
)
|
||||
instances_failed.append(instance['instance_id'])
|
||||
logger.info(
|
||||
f'{len(instances_failed)} instances failed the current attempt {attempt}: {instances_failed}'
|
||||
)
|
||||
eval_ids = instances_failed
|
||||
|
||||
# If no instances failed, we break
|
||||
if len(instances_failed) == 0:
|
||||
break
|
||||
|
||||
# Then we should aggregate the results from all attempts into the original output file
|
||||
# and remove the intermediate files
|
||||
logger.info(
|
||||
'Aggregating results from all attempts into the original output file...'
|
||||
)
|
||||
fout = open(output_file, 'w')
|
||||
added_instance_ids = set()
|
||||
for attempt in reversed(range(1, ITERATIVE_EVAL_MODE_MAX_ATTEMPTS + 1)):
|
||||
cur_output_file = get_cur_output_file_path(attempt)
|
||||
if not os.path.exists(cur_output_file):
|
||||
logger.warning(
|
||||
f'Intermediate output file {cur_output_file} does not exist. Skipping...'
|
||||
)
|
||||
continue
|
||||
|
||||
with open(cur_output_file, 'r') as f:
|
||||
for line in f:
|
||||
instance = json.loads(line)
|
||||
# Also make sure git_patch is not empty - otherwise we fall back to previous attempt (empty patch is worse than anything else)
|
||||
if (
|
||||
instance['instance_id'] not in added_instance_ids
|
||||
and instance['test_result'].get('git_patch', '').strip()
|
||||
):
|
||||
fout.write(line)
|
||||
added_instance_ids.add(instance['instance_id'])
|
||||
logger.info(
|
||||
f'Aggregated instances from {cur_output_file}. Total instances added so far: {len(added_instance_ids)}'
|
||||
)
|
||||
fout.close()
|
||||
logger.info(
|
||||
f'Done! Total {len(added_instance_ids)} instances added to {output_file}'
|
||||
)
|
||||
148
evaluation/benchmarks/swefficiency/scripts/run_infer.sh
Executable file
148
evaluation/benchmarks/swefficiency/scripts/run_infer.sh
Executable file
@@ -0,0 +1,148 @@
|
||||
#!/usr/bin/env bash
|
||||
set -eo pipefail
|
||||
|
||||
source "evaluation/utils/version_control.sh"
|
||||
|
||||
MODEL_CONFIG=$1
|
||||
COMMIT_HASH=$2
|
||||
AGENT=$3
|
||||
EVAL_LIMIT=$4
|
||||
MAX_ITER=$5
|
||||
NUM_WORKERS=$6
|
||||
DATASET=$7
|
||||
SPLIT=$8
|
||||
N_RUNS=$9
|
||||
MODE=${10}
|
||||
|
||||
|
||||
if [ -z "$NUM_WORKERS" ]; then
|
||||
NUM_WORKERS=1
|
||||
echo "Number of workers not specified, use default $NUM_WORKERS"
|
||||
fi
|
||||
checkout_eval_branch
|
||||
|
||||
if [ -z "$AGENT" ]; then
|
||||
echo "Agent not specified, use default CodeActAgent"
|
||||
AGENT="CodeActAgent"
|
||||
fi
|
||||
|
||||
if [ -z "$MAX_ITER" ]; then
|
||||
echo "MAX_ITER not specified, use default 100"
|
||||
MAX_ITER=100
|
||||
fi
|
||||
|
||||
if [ -z "$RUN_WITH_BROWSING" ]; then
|
||||
echo "RUN_WITH_BROWSING not specified, use default false"
|
||||
RUN_WITH_BROWSING=false
|
||||
fi
|
||||
|
||||
|
||||
if [ -z "$DATASET" ]; then
|
||||
echo "DATASET not specified, use default princeton-nlp/SWE-bench_Lite"
|
||||
DATASET="swefficiency/swefficiency"
|
||||
fi
|
||||
|
||||
if [ -z "$SPLIT" ]; then
|
||||
echo "SPLIT not specified, use default test"
|
||||
SPLIT="test"
|
||||
fi
|
||||
|
||||
if [ -z "$MODE" ]; then
|
||||
MODE="swe"
|
||||
echo "MODE not specified, use default $MODE"
|
||||
fi
|
||||
|
||||
if [ -n "$EVAL_CONDENSER" ]; then
|
||||
echo "Using Condenser Config: $EVAL_CONDENSER"
|
||||
else
|
||||
echo "No Condenser Config provided via EVAL_CONDENSER, use default (NoOpCondenser)."
|
||||
fi
|
||||
|
||||
export RUN_WITH_BROWSING=$RUN_WITH_BROWSING
|
||||
echo "RUN_WITH_BROWSING: $RUN_WITH_BROWSING"
|
||||
|
||||
get_openhands_version
|
||||
|
||||
echo "AGENT: $AGENT"
|
||||
echo "OPENHANDS_VERSION: $OPENHANDS_VERSION"
|
||||
echo "MODEL_CONFIG: $MODEL_CONFIG"
|
||||
echo "DATASET: $DATASET"
|
||||
echo "SPLIT: $SPLIT"
|
||||
echo "MAX_ITER: $MAX_ITER"
|
||||
echo "NUM_WORKERS: $NUM_WORKERS"
|
||||
echo "COMMIT_HASH: $COMMIT_HASH"
|
||||
echo "MODE: $MODE"
|
||||
echo "EVAL_CONDENSER: $EVAL_CONDENSER"
|
||||
|
||||
# Default to NOT use Hint
|
||||
if [ -z "$USE_HINT_TEXT" ]; then
|
||||
export USE_HINT_TEXT=false
|
||||
fi
|
||||
echo "USE_HINT_TEXT: $USE_HINT_TEXT"
|
||||
EVAL_NOTE="$OPENHANDS_VERSION"
|
||||
# if not using Hint, add -no-hint to the eval note
|
||||
if [ "$USE_HINT_TEXT" = false ]; then
|
||||
EVAL_NOTE="$EVAL_NOTE-no-hint"
|
||||
fi
|
||||
|
||||
if [ "$RUN_WITH_BROWSING" = true ]; then
|
||||
EVAL_NOTE="$EVAL_NOTE-with-browsing"
|
||||
fi
|
||||
|
||||
if [ -n "$EXP_NAME" ]; then
|
||||
EVAL_NOTE="$EVAL_NOTE-$EXP_NAME"
|
||||
fi
|
||||
# if mode != swe, add mode to the eval note
|
||||
if [ "$MODE" != "swe" ]; then
|
||||
EVAL_NOTE="${EVAL_NOTE}-${MODE}"
|
||||
fi
|
||||
# Add condenser config to eval note if provided
|
||||
if [ -n "$EVAL_CONDENSER" ]; then
|
||||
EVAL_NOTE="${EVAL_NOTE}-${EVAL_CONDENSER}"
|
||||
fi
|
||||
|
||||
# export RUNTIME="remote"
|
||||
# export SANDBOX_REMOTE_RUNTIME_API_URL="https://runtime.eval.all-hands.dev"
|
||||
export NO_CHANGE_TIMEOUT_SECONDS=900 # 15 minutes
|
||||
|
||||
function run_eval() {
|
||||
local eval_note="${1}"
|
||||
COMMAND="poetry run python evaluation/benchmarks/swefficiency/run_infer.py \
|
||||
--agent-cls $AGENT \
|
||||
--llm-config $MODEL_CONFIG \
|
||||
--max-iterations $MAX_ITER \
|
||||
--eval-num-workers $NUM_WORKERS \
|
||||
--eval-note $eval_note \
|
||||
--dataset $DATASET \
|
||||
--split $SPLIT \
|
||||
--mode $MODE"
|
||||
|
||||
if [ -n "$EVAL_LIMIT" ]; then
|
||||
echo "EVAL_LIMIT: $EVAL_LIMIT"
|
||||
COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
|
||||
fi
|
||||
|
||||
# Run the command
|
||||
eval $COMMAND
|
||||
}
|
||||
|
||||
unset SANDBOX_ENV_GITHUB_TOKEN # prevent the agent from using the github token to push
|
||||
if [ -z "$N_RUNS" ]; then
|
||||
N_RUNS=1
|
||||
echo "N_RUNS not specified, use default $N_RUNS"
|
||||
fi
|
||||
|
||||
# Skip runs if the run number is in the SKIP_RUNS list
|
||||
# read from env variable SKIP_RUNS as a comma separated list of run numbers
|
||||
SKIP_RUNS=(${SKIP_RUNS//,/ })
|
||||
for i in $(seq 1 $N_RUNS); do
|
||||
if [[ " ${SKIP_RUNS[@]} " =~ " $i " ]]; then
|
||||
echo "Skipping run $i"
|
||||
continue
|
||||
fi
|
||||
current_eval_note="$EVAL_NOTE-run_$i"
|
||||
echo "EVAL_NOTE: $current_eval_note"
|
||||
run_eval $current_eval_note
|
||||
done
|
||||
|
||||
checkout_original_branch
|
||||
43
evaluation/benchmarks/swefficiency/scripts/setup/instance_swe_entry.sh
Executable file
43
evaluation/benchmarks/swefficiency/scripts/setup/instance_swe_entry.sh
Executable file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
source ~/.bashrc
|
||||
SWEUTIL_DIR=/swe_util
|
||||
|
||||
# FIXME: Cannot read SWE_INSTANCE_ID from the environment variable
|
||||
# SWE_INSTANCE_ID=django__django-11099
|
||||
if [ -z "$SWE_INSTANCE_ID" ]; then
|
||||
echo "Error: SWE_INSTANCE_ID is not set." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Read the swe-bench-test-lite.json file and extract the required item based on instance_id
|
||||
item=$(jq --arg INSTANCE_ID "$SWE_INSTANCE_ID" '.[] | select(.instance_id == $INSTANCE_ID)' $SWEUTIL_DIR/eval_data/instances/swe-bench-instance.json)
|
||||
|
||||
if [[ -z "$item" ]]; then
|
||||
echo "No item found for the provided instance ID."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
WORKSPACE_NAME=$(echo "$item" | jq -r '(.repo | tostring) + "__" + (.version | tostring) | gsub("/"; "__")')
|
||||
|
||||
echo "WORKSPACE_NAME: $WORKSPACE_NAME"
|
||||
|
||||
# Clear the workspace
|
||||
if [ -d /workspace ]; then
|
||||
rm -rf /workspace/*
|
||||
else
|
||||
mkdir /workspace
|
||||
fi
|
||||
# Copy repo to workspace
|
||||
if [ -d /workspace/$WORKSPACE_NAME ]; then
|
||||
rm -rf /workspace/$WORKSPACE_NAME
|
||||
fi
|
||||
mkdir -p /workspace
|
||||
cp -r /testbed /workspace/$WORKSPACE_NAME
|
||||
|
||||
# Activate instance-specific environment
|
||||
if [ -d /opt/miniconda3 ]; then
|
||||
. /opt/miniconda3/etc/profile.d/conda.sh
|
||||
conda activate testbed
|
||||
fi
|
||||
27
evaluation/benchmarks/swefficiency/scripts/setup/prepare_swe_utils.sh
Executable file
27
evaluation/benchmarks/swefficiency/scripts/setup/prepare_swe_utils.sh
Executable file
@@ -0,0 +1,27 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
EVAL_WORKSPACE="evaluation/benchmarks/swe_bench/eval_workspace"
|
||||
mkdir -p $EVAL_WORKSPACE
|
||||
|
||||
# 1. Prepare REPO
|
||||
echo "==== Prepare SWE-bench repo ===="
|
||||
OH_SWE_BENCH_REPO_PATH="https://github.com/All-Hands-AI/SWE-bench.git"
|
||||
OH_SWE_BENCH_REPO_BRANCH="eval"
|
||||
git clone -b $OH_SWE_BENCH_REPO_BRANCH $OH_SWE_BENCH_REPO_PATH $EVAL_WORKSPACE/OH-SWE-bench
|
||||
|
||||
# 2. Prepare DATA
|
||||
echo "==== Prepare SWE-bench data ===="
|
||||
EVAL_IMAGE=ghcr.io/all-hands-ai/eval-swe-bench:builder_with_conda
|
||||
EVAL_WORKSPACE=$(realpath $EVAL_WORKSPACE)
|
||||
chmod +x $EVAL_WORKSPACE/OH-SWE-bench/swebench/harness/prepare_data.sh
|
||||
if [ -d $EVAL_WORKSPACE/eval_data ]; then
|
||||
rm -r $EVAL_WORKSPACE/eval_data
|
||||
fi
|
||||
docker run \
|
||||
-v $EVAL_WORKSPACE:/workspace \
|
||||
-w /workspace \
|
||||
-u $(id -u):$(id -g) \
|
||||
-e HF_DATASETS_CACHE="/tmp" \
|
||||
--rm -it $EVAL_IMAGE \
|
||||
bash -c "cd OH-SWE-bench/swebench/harness && /swe_util/miniforge3/bin/conda run -n swe-bench-eval ./prepare_data.sh && mv eval_data /workspace/"
|
||||
96
evaluation/benchmarks/swefficiency/scripts/setup/swe_entry.sh
Executable file
96
evaluation/benchmarks/swefficiency/scripts/setup/swe_entry.sh
Executable file
@@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
# assert user name is `root`
|
||||
if [ "$USER" != "root" ]; then
|
||||
echo "Error: This script is intended to be run by the 'root' user only." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
source ~/.bashrc
|
||||
|
||||
SWEUTIL_DIR=/swe_util
|
||||
|
||||
# Create logs directory
|
||||
LOG_DIR=/openhands/logs
|
||||
mkdir -p $LOG_DIR && chmod 777 $LOG_DIR
|
||||
|
||||
# FIXME: Cannot read SWE_INSTANCE_ID from the environment variable
|
||||
# SWE_INSTANCE_ID=django__django-11099
|
||||
if [ -z "$SWE_INSTANCE_ID" ]; then
|
||||
echo "Error: SWE_INSTANCE_ID is not set." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Read the swe-bench-test-lite.json file and extract the required item based on instance_id
|
||||
item=$(jq --arg INSTANCE_ID "$SWE_INSTANCE_ID" '.[] | select(.instance_id == $INSTANCE_ID)' $SWEUTIL_DIR/eval_data/instances/swe-bench-test-lite.json)
|
||||
|
||||
if [[ -z "$item" ]]; then
|
||||
echo "No item found for the provided instance ID."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
CONDA_ENV_NAME=$(echo "$item" | jq -r '.repo + "__" + .version | gsub("/"; "__")')
|
||||
|
||||
echo "CONDA_ENV_NAME: $CONDA_ENV_NAME"
|
||||
|
||||
SWE_TASK_DIR=/openhands/swe_tasks
|
||||
mkdir -p $SWE_TASK_DIR
|
||||
# Dump test_patch to /workspace/test.patch
|
||||
echo "$item" | jq -r '.test_patch' > $SWE_TASK_DIR/test.patch
|
||||
# Dump patch to /workspace/gold.patch
|
||||
echo "$item" | jq -r '.patch' > $SWE_TASK_DIR/gold.patch
|
||||
# Dump the item to /workspace/instance.json except for the "test_patch" and "patch" fields
|
||||
echo "$item" | jq 'del(.test_patch, .patch)' > $SWE_TASK_DIR/instance.json
|
||||
|
||||
# Clear the workspace
|
||||
rm -rf /workspace/*
|
||||
# Copy repo to workspace
|
||||
if [ -d /workspace/$CONDA_ENV_NAME ]; then
|
||||
rm -rf /workspace/$CONDA_ENV_NAME
|
||||
fi
|
||||
cp -r $SWEUTIL_DIR/eval_data/testbeds/$CONDA_ENV_NAME /workspace
|
||||
|
||||
# Reset swe-bench testbed and install the repo
|
||||
. $SWEUTIL_DIR/miniforge3/etc/profile.d/conda.sh
|
||||
conda config --set changeps1 False
|
||||
conda config --append channels conda-forge
|
||||
conda activate swe-bench-eval
|
||||
|
||||
mkdir -p $SWE_TASK_DIR/reset_testbed_temp
|
||||
mkdir -p $SWE_TASK_DIR/reset_testbed_log_dir
|
||||
SWE_BENCH_DIR=/swe_util/OH-SWE-bench
|
||||
output=$(
|
||||
export PYTHONPATH=$SWE_BENCH_DIR && \
|
||||
cd $SWE_BENCH_DIR && \
|
||||
python swebench/harness/reset_swe_env.py \
|
||||
--swe_bench_tasks $SWEUTIL_DIR/eval_data/instances/swe-bench-test.json \
|
||||
--temp_dir $SWE_TASK_DIR/reset_testbed_temp \
|
||||
--testbed /workspace \
|
||||
--conda_path $SWEUTIL_DIR/miniforge3 \
|
||||
--instance_id $SWE_INSTANCE_ID \
|
||||
--log_dir $SWE_TASK_DIR/reset_testbed_log_dir \
|
||||
--timeout 900 \
|
||||
--verbose
|
||||
)
|
||||
|
||||
REPO_PATH=$(echo "$output" | awk -F': ' '/repo_path:/ {print $2}')
|
||||
TEST_CMD=$(echo "$output" | awk -F': ' '/test_cmd:/ {print $2}')
|
||||
echo "Repo Path: $REPO_PATH"
|
||||
echo "Test Command: $TEST_CMD"
|
||||
|
||||
echo "export SWE_BENCH_DIR=\"$SWE_BENCH_DIR\"" >> ~/.bashrc
|
||||
echo "export REPO_PATH=\"$REPO_PATH\"" >> ~/.bashrc
|
||||
echo "export TEST_CMD=\"$TEST_CMD\"" >> ~/.bashrc
|
||||
|
||||
if [[ "$REPO_PATH" == "None" ]]; then
|
||||
echo "Error: Failed to retrieve repository path. Tests may not have passed or output was not as expected." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Activate instance-specific environment
|
||||
. $SWEUTIL_DIR/miniforge3/etc/profile.d/conda.sh
|
||||
conda activate $CONDA_ENV_NAME
|
||||
|
||||
set +e
|
||||
@@ -61,7 +61,7 @@ describe("ExpandableMessage", () => {
|
||||
expect(icon).toHaveClass("fill-success");
|
||||
});
|
||||
|
||||
it("should render with error icon for failed action messages", () => {
|
||||
it("should render with no icon for failed action messages", () => {
|
||||
renderWithProviders(
|
||||
<ExpandableMessage
|
||||
id="OBSERVATION_MESSAGE$RUN"
|
||||
@@ -75,8 +75,7 @@ describe("ExpandableMessage", () => {
|
||||
"div.flex.gap-2.items-center.justify-start",
|
||||
);
|
||||
expect(container).toHaveClass("border-neutral-300");
|
||||
const icon = screen.getByTestId("status-icon");
|
||||
expect(icon).toHaveClass("fill-danger");
|
||||
expect(screen.queryByTestId("status-icon")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render with neutral border and no icon for action messages without success prop", () => {
|
||||
|
||||
@@ -3,7 +3,7 @@ import { describe, expect, it, vi } from "vitest";
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { AnalyticsConsentFormModal } from "#/components/features/analytics/analytics-consent-form-modal";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
|
||||
describe("AnalyticsConsentFormModal", () => {
|
||||
it("should call saveUserSettings with consent", async () => {
|
||||
|
||||
@@ -3,7 +3,7 @@ import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { QueryClientProvider, QueryClient } from "@tanstack/react-query";
|
||||
import { createRoutesStub, Outlet } from "react-router";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import ConversationService from "#/api/conversation-service/conversation-service.api";
|
||||
import GitService from "#/api/git-service/git-service.api";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
@@ -404,7 +404,7 @@ describe("RepoConnector", () => {
|
||||
ConversationService,
|
||||
"createConversation",
|
||||
);
|
||||
createConversationSpy.mockImplementation(() => new Promise(() => {})); // Never resolves to keep loading state
|
||||
createConversationSpy.mockImplementation(() => new Promise(() => { })); // Never resolves to keep loading state
|
||||
const retrieveUserGitRepositoriesSpy = vi.spyOn(
|
||||
GitService,
|
||||
"retrieveUserGitRepositories",
|
||||
|
||||
@@ -3,7 +3,7 @@ import { renderWithProviders } from "test-utils";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { waitFor } from "@testing-library/react";
|
||||
import { Sidebar } from "#/components/features/sidebar/sidebar";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
|
||||
// These tests will now fail because the conversation panel is rendered through a portal
|
||||
// and technically not a child of the Sidebar component.
|
||||
|
||||
@@ -57,7 +57,7 @@ describe("MicroagentsModal - Refresh Button", () => {
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe("Refresh Button Rendering", () => {
|
||||
@@ -74,13 +74,15 @@ describe("MicroagentsModal - Refresh Button", () => {
|
||||
describe("Refresh Button Functionality", () => {
|
||||
it("should call refetch when refresh button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const refreshSpy = vi.spyOn(ConversationService, "getMicroagents");
|
||||
|
||||
renderWithProviders(<MicroagentsModal {...defaultProps} />);
|
||||
|
||||
const refreshSpy = vi.spyOn(ConversationService, "getMicroagents");
|
||||
|
||||
// Wait for the component to load and render the refresh button
|
||||
const refreshButton = await screen.findByTestId("refresh-microagents");
|
||||
|
||||
refreshSpy.mockClear();
|
||||
|
||||
await user.click(refreshButton);
|
||||
|
||||
expect(refreshSpy).toHaveBeenCalledTimes(1);
|
||||
|
||||
@@ -3,7 +3,7 @@ import { describe, expect, it, vi } from "vitest";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { screen } from "@testing-library/react";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { SettingsForm } from "#/components/shared/modals/settings/settings-form";
|
||||
import { DEFAULT_SETTINGS } from "#/services/settings";
|
||||
|
||||
|
||||
@@ -1,12 +1,26 @@
|
||||
import { describe, it, expect, beforeAll, afterAll, afterEach } from "vitest";
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeAll,
|
||||
beforeEach,
|
||||
afterAll,
|
||||
afterEach,
|
||||
} from "vitest";
|
||||
import { screen, waitFor, render, cleanup } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { http, HttpResponse } from "msw";
|
||||
import { useOptimisticUserMessageStore } from "#/stores/optimistic-user-message-store";
|
||||
import { useBrowserStore } from "#/stores/browser-store";
|
||||
import { useCommandStore } from "#/state/command-store";
|
||||
import {
|
||||
createMockMessageEvent,
|
||||
createMockUserMessageEvent,
|
||||
createMockAgentErrorEvent,
|
||||
createMockBrowserObservationEvent,
|
||||
createMockBrowserNavigateActionEvent,
|
||||
createMockExecuteBashActionEvent,
|
||||
createMockExecuteBashObservationEvent,
|
||||
} from "#/mocks/mock-ws-helpers";
|
||||
import {
|
||||
ConnectionStatusComponent,
|
||||
@@ -461,7 +475,7 @@ describe("Conversation WebSocket Handler", () => {
|
||||
);
|
||||
|
||||
// Create a test component that displays loading state
|
||||
const HistoryLoadingComponent = () => {
|
||||
function HistoryLoadingComponent() {
|
||||
const context = useConversationWebSocket();
|
||||
const { events } = useEventStore();
|
||||
|
||||
@@ -474,7 +488,7 @@ describe("Conversation WebSocket Handler", () => {
|
||||
<div data-testid="expected-event-count">{expectedEventCount}</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
// Render with WebSocket context
|
||||
renderWithWebSocketContext(
|
||||
@@ -484,7 +498,9 @@ describe("Conversation WebSocket Handler", () => {
|
||||
);
|
||||
|
||||
// Initially should be loading history
|
||||
expect(screen.getByTestId("is-loading-history")).toHaveTextContent("true");
|
||||
expect(screen.getByTestId("is-loading-history")).toHaveTextContent(
|
||||
"true",
|
||||
);
|
||||
|
||||
// Wait for all events to be received
|
||||
await waitFor(() => {
|
||||
@@ -523,7 +539,7 @@ describe("Conversation WebSocket Handler", () => {
|
||||
);
|
||||
|
||||
// Create a test component that displays loading state
|
||||
const HistoryLoadingComponent = () => {
|
||||
function HistoryLoadingComponent() {
|
||||
const context = useConversationWebSocket();
|
||||
|
||||
return (
|
||||
@@ -533,7 +549,7 @@ describe("Conversation WebSocket Handler", () => {
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
// Render with WebSocket context
|
||||
renderWithWebSocketContext(
|
||||
@@ -583,7 +599,7 @@ describe("Conversation WebSocket Handler", () => {
|
||||
);
|
||||
|
||||
// Create a test component that displays loading state
|
||||
const HistoryLoadingComponent = () => {
|
||||
function HistoryLoadingComponent() {
|
||||
const context = useConversationWebSocket();
|
||||
const { events } = useEventStore();
|
||||
|
||||
@@ -595,7 +611,7 @@ describe("Conversation WebSocket Handler", () => {
|
||||
<div data-testid="events-received">{events.length}</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
// Render with WebSocket context
|
||||
renderWithWebSocketContext(
|
||||
@@ -605,7 +621,9 @@ describe("Conversation WebSocket Handler", () => {
|
||||
);
|
||||
|
||||
// Initially should be loading history
|
||||
expect(screen.getByTestId("is-loading-history")).toHaveTextContent("true");
|
||||
expect(screen.getByTestId("is-loading-history")).toHaveTextContent(
|
||||
"true",
|
||||
);
|
||||
|
||||
// Wait for all events to be received
|
||||
await waitFor(() => {
|
||||
@@ -621,17 +639,133 @@ describe("Conversation WebSocket Handler", () => {
|
||||
});
|
||||
});
|
||||
|
||||
// 9. Terminal I/O Tests (ExecuteBashAction and ExecuteBashObservation)
|
||||
describe("Terminal I/O Integration", () => {
|
||||
it("should append command to store when ExecuteBashAction event is received", async () => {
|
||||
const { createMockExecuteBashActionEvent } = await import(
|
||||
"#/mocks/mock-ws-helpers"
|
||||
// 9. Browser State Tests (BrowserObservation)
|
||||
describe("Browser State Integration", () => {
|
||||
beforeEach(() => {
|
||||
useBrowserStore.getState().reset();
|
||||
});
|
||||
|
||||
it("should update browser store with screenshot when BrowserObservation event is received", async () => {
|
||||
// Create a mock BrowserObservation event with screenshot data
|
||||
const mockBrowserObsEvent = createMockBrowserObservationEvent(
|
||||
"base64-screenshot-data",
|
||||
"Page loaded successfully",
|
||||
);
|
||||
const { useCommandStore } = await import("#/state/command-store");
|
||||
|
||||
// Clear the command store before test
|
||||
// Set up MSW to send the event when connection is established
|
||||
mswServer.use(
|
||||
wsLink.addEventListener("connection", ({ client, server }) => {
|
||||
server.connect();
|
||||
// Send the mock event after connection
|
||||
client.send(JSON.stringify(mockBrowserObsEvent));
|
||||
}),
|
||||
);
|
||||
|
||||
// Render with WebSocket context
|
||||
renderWithWebSocketContext(<ConnectionStatusComponent />);
|
||||
|
||||
// Wait for connection
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("connection-state")).toHaveTextContent(
|
||||
"OPEN",
|
||||
);
|
||||
});
|
||||
|
||||
// Wait for the browser store to be updated with screenshot
|
||||
await waitFor(() => {
|
||||
const { screenshotSrc } = useBrowserStore.getState();
|
||||
expect(screenshotSrc).toBe(
|
||||
"data:image/png;base64,base64-screenshot-data",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should update browser store with URL when BrowserNavigateAction followed by BrowserObservation", async () => {
|
||||
// Create mock events - action first, then observation
|
||||
const mockBrowserActionEvent = createMockBrowserNavigateActionEvent(
|
||||
"https://example.com/test-page",
|
||||
);
|
||||
const mockBrowserObsEvent = createMockBrowserObservationEvent(
|
||||
"base64-screenshot-data",
|
||||
"Page loaded successfully",
|
||||
);
|
||||
|
||||
// Set up MSW to send both events when connection is established
|
||||
mswServer.use(
|
||||
wsLink.addEventListener("connection", ({ client, server }) => {
|
||||
server.connect();
|
||||
// Send action first, then observation
|
||||
client.send(JSON.stringify(mockBrowserActionEvent));
|
||||
client.send(JSON.stringify(mockBrowserObsEvent));
|
||||
}),
|
||||
);
|
||||
|
||||
// Render with WebSocket context
|
||||
renderWithWebSocketContext(<ConnectionStatusComponent />);
|
||||
|
||||
// Wait for connection
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("connection-state")).toHaveTextContent(
|
||||
"OPEN",
|
||||
);
|
||||
});
|
||||
|
||||
// Wait for the browser store to be updated with both screenshot and URL
|
||||
await waitFor(() => {
|
||||
const { screenshotSrc, url } = useBrowserStore.getState();
|
||||
expect(screenshotSrc).toBe(
|
||||
"data:image/png;base64,base64-screenshot-data",
|
||||
);
|
||||
expect(url).toBe("https://example.com/test-page");
|
||||
});
|
||||
});
|
||||
|
||||
it("should not update browser store when BrowserObservation has no screenshot data", async () => {
|
||||
const initialScreenshot = useBrowserStore.getState().screenshotSrc;
|
||||
|
||||
// Create a mock BrowserObservation event WITHOUT screenshot data
|
||||
const mockBrowserObsEvent = createMockBrowserObservationEvent(
|
||||
null, // no screenshot
|
||||
"Browser action completed",
|
||||
);
|
||||
|
||||
// Set up MSW to send the event when connection is established
|
||||
mswServer.use(
|
||||
wsLink.addEventListener("connection", ({ client, server }) => {
|
||||
server.connect();
|
||||
// Send the mock event after connection
|
||||
client.send(JSON.stringify(mockBrowserObsEvent));
|
||||
}),
|
||||
);
|
||||
|
||||
// Render with WebSocket context
|
||||
renderWithWebSocketContext(<ConnectionStatusComponent />);
|
||||
|
||||
// Wait for connection
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("connection-state")).toHaveTextContent(
|
||||
"OPEN",
|
||||
);
|
||||
});
|
||||
|
||||
// Give some time for any potential updates
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, 100);
|
||||
});
|
||||
|
||||
// Screenshot should remain unchanged (empty/initial value)
|
||||
const { screenshotSrc } = useBrowserStore.getState();
|
||||
expect(screenshotSrc).toBe(initialScreenshot);
|
||||
});
|
||||
});
|
||||
|
||||
// 10. Terminal I/O Tests (ExecuteBashAction and ExecuteBashObservation)
|
||||
describe("Terminal I/O Integration", () => {
|
||||
beforeEach(() => {
|
||||
useCommandStore.getState().clearTerminal();
|
||||
});
|
||||
|
||||
it("should append command to store when ExecuteBashAction event is received", async () => {
|
||||
// Create a mock ExecuteBashAction event
|
||||
const mockBashActionEvent = createMockExecuteBashActionEvent("npm test");
|
||||
|
||||
@@ -667,14 +801,6 @@ describe("Conversation WebSocket Handler", () => {
|
||||
});
|
||||
|
||||
it("should append output to store when ExecuteBashObservation event is received", async () => {
|
||||
const { createMockExecuteBashObservationEvent } = await import(
|
||||
"#/mocks/mock-ws-helpers"
|
||||
);
|
||||
const { useCommandStore } = await import("#/state/command-store");
|
||||
|
||||
// Clear the command store before test
|
||||
useCommandStore.getState().clearTerminal();
|
||||
|
||||
// Create a mock ExecuteBashObservation event
|
||||
const mockBashObservationEvent = createMockExecuteBashObservationEvent(
|
||||
"PASS tests/example.test.js\n ✓ should work (2 ms)",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { useSaveSettings } from "#/hooks/mutation/use-save-settings";
|
||||
|
||||
describe("useSaveSettings", () => {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
/* eslint-disable max-classes-per-file */
|
||||
import { beforeAll, describe, expect, it, vi, afterEach } from "vitest";
|
||||
import { useTerminal } from "#/hooks/use-terminal";
|
||||
import { Command, useCommandStore } from "#/state/command-store";
|
||||
@@ -45,17 +46,29 @@ describe("useTerminal", () => {
|
||||
}));
|
||||
|
||||
beforeAll(() => {
|
||||
// mock ResizeObserver
|
||||
window.ResizeObserver = vi.fn().mockImplementation(() => ({
|
||||
observe: vi.fn(),
|
||||
unobserve: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
}));
|
||||
// mock ResizeObserver - use class for Vitest 4 constructor support
|
||||
window.ResizeObserver = class {
|
||||
observe = vi.fn();
|
||||
|
||||
// mock Terminal
|
||||
unobserve = vi.fn();
|
||||
|
||||
disconnect = vi.fn();
|
||||
} as unknown as typeof ResizeObserver;
|
||||
|
||||
// mock Terminal - use class for Vitest 4 constructor support
|
||||
vi.mock("@xterm/xterm", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("@xterm/xterm")>()),
|
||||
Terminal: vi.fn().mockImplementation(() => mockTerminal),
|
||||
Terminal: class {
|
||||
loadAddon = mockTerminal.loadAddon;
|
||||
|
||||
open = mockTerminal.open;
|
||||
|
||||
write = mockTerminal.write;
|
||||
|
||||
writeln = mockTerminal.writeln;
|
||||
|
||||
dispose = mockTerminal.dispose;
|
||||
},
|
||||
}));
|
||||
});
|
||||
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
/**
|
||||
* TODO: Fix flaky WebSocket tests (https://github.com/OpenHands/OpenHands/issues/11944)
|
||||
*
|
||||
* Several tests in this file are skipped because they fail intermittently in CI
|
||||
* but pass locally. The SUSPECTED root cause is that `wsLink.broadcast()` sends messages
|
||||
* to ALL connected clients across all tests, causing cross-test contamination
|
||||
* when tests run in parallel with Vitest v4.
|
||||
*/
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import {
|
||||
describe,
|
||||
@@ -51,7 +59,7 @@ describe("useWebSocket", () => {
|
||||
expect(result.current.socket).toBeTruthy();
|
||||
});
|
||||
|
||||
it("should handle incoming messages correctly", async () => {
|
||||
it.skip("should handle incoming messages correctly", async () => {
|
||||
const { result } = renderHook(() => useWebSocket("ws://acme.com/ws"));
|
||||
|
||||
// Wait for connection to be established
|
||||
@@ -114,7 +122,7 @@ describe("useWebSocket", () => {
|
||||
expect(result.current.socket).toBeTruthy();
|
||||
});
|
||||
|
||||
it("should close the WebSocket connection on unmount", async () => {
|
||||
it.skip("should close the WebSocket connection on unmount", async () => {
|
||||
const { result, unmount } = renderHook(() =>
|
||||
useWebSocket("ws://acme.com/ws"),
|
||||
);
|
||||
@@ -204,7 +212,7 @@ describe("useWebSocket", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("should call onMessage handler when WebSocket receives a message", async () => {
|
||||
it.skip("should call onMessage handler when WebSocket receives a message", async () => {
|
||||
const onMessageSpy = vi.fn();
|
||||
const options = { onMessage: onMessageSpy };
|
||||
|
||||
@@ -271,7 +279,7 @@ describe("useWebSocket", () => {
|
||||
expect(onErrorSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should provide sendMessage function to send messages to WebSocket", async () => {
|
||||
it.skip("should provide sendMessage function to send messages to WebSocket", async () => {
|
||||
const { result } = renderHook(() => useWebSocket("ws://acme.com/ws"));
|
||||
|
||||
// Wait for connection to be established
|
||||
|
||||
@@ -10,7 +10,7 @@ import MainApp from "#/routes/root-layout";
|
||||
import i18n from "#/i18n";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import * as CaptureConsent from "#/utils/handle-capture-consent";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import * as ToastHandlers from "#/utils/custom-toast-handlers";
|
||||
|
||||
describe("frontend/routes/_oh", () => {
|
||||
|
||||
@@ -3,7 +3,7 @@ import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import AppSettingsScreen from "#/routes/app-settings";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { AvailableLanguages } from "#/i18n";
|
||||
import * as CaptureConsent from "#/utils/handle-capture-consent";
|
||||
|
||||
@@ -6,7 +6,7 @@ import userEvent from "@testing-library/user-event";
|
||||
import i18next from "i18next";
|
||||
import { I18nextProvider } from "react-i18next";
|
||||
import GitSettingsScreen from "#/routes/git-settings";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import AuthService from "#/api/auth-service/auth-service.api";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
|
||||
@@ -6,7 +6,7 @@ import { createRoutesStub } from "react-router";
|
||||
import { createAxiosNotFoundErrorObject } from "test-utils";
|
||||
import HomeScreen from "#/routes/home";
|
||||
import { GitRepository } from "#/types/git";
|
||||
import SettingsService from "#/settings-service/settings-service.api";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import GitService from "#/api/git-service/git-service.api";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import MainApp from "#/routes/root-layout";
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user