Compare commits

...

26 Commits

Author SHA1 Message Date
openhands
9e9a0bbe87 Optimize Playwright install by removing redundant --with-deps flag
The --with-deps flag causes Playwright to install system dependencies via apt-get,
but these dependencies are already being installed manually in the lines above.
This removal avoids redundant package installation and significantly speeds up
the Docker build process.

Added libcups2t64/libcups2 to the manual dependency list as it's required by
Playwright but wasn't previously included.

Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-03 06:30:47 +00:00
openhands
c9af5edad9 Optimize Playwright install by removing redundant --with-deps flag
The --with-deps flag causes Playwright to install system dependencies via apt-get,
but these dependencies (libnss3, libnspr4, libatk-bridge2.0-0, etc.) are already
being installed manually in the lines above. This removal avoids redundant package
installation and significantly speeds up the Docker build process.

Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-03 06:11:04 +00:00
chuckbutkus
0c7ce4ad48 V1 Changes to Support Path Based Routing (#13120)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-02 22:37:37 -05:00
Rohit Malhotra
4dab34e7b0 fix(enterprise): fix type errors - missing returns and async interface (#13145)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-03 00:37:22 +00:00
Rohit Malhotra
f8bbd352a9 Fix typing: make Message a dict instead of dict | str (#13144)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-03 00:30:22 +00:00
Tim O'Farrell
17347a95f8 Make load_org_token and store_org_token async in TokenManager (#13147)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-02 17:08:21 -07:00
Graham Neubig
01ef87aaaa Add logging when sandbox is assigned to conversation (#13143)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-02 17:36:49 -05:00
Hiep Le
8059c18b57 fix(backend): update planning agent to direct users to the build button instead of asking ready to proceed (#13139) 2026-03-03 03:31:29 +07:00
Tim O'Farrell
c82ee4c7db refactor(enterprise): use async database sessions in feedback routes (#13137)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-02 13:17:44 -07:00
Tim O'Farrell
7fdb423f99 feat(enterprise): convert DeviceCodeStore to async (#13136)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-02 12:56:41 -07:00
dependabot[bot]
530065dfa7 chore(deps): bump pillow from 12.1.0 to 12.1.1 in uv lock and enterprise poetry lock (#13101)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-02 13:56:13 -06:00
Hiep Le
a4cd2d81a5 fix(backend): use run_coroutine_threadsafe for conversation update callbacks (#13134) 2026-03-03 02:07:32 +07:00
Tim O'Farrell
003b430e96 Refactor: Migrate remaining enterprise modules to async database sessions (#13124)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-02 11:52:00 -07:00
Graham Neubig
d63565186e Add Claude Opus 4.6 model support (#12767)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: neubig <neubig@users.noreply.github.com>
2026-03-02 13:12:48 -05:00
Hiep Le
5f42d03ec5 fix(backend): jira cloud integration does not work (#13123) 2026-03-02 22:05:29 +07:00
Mohammed Abdulai
62241e2e00 Fix: OSS suggested tasks empty state (#12563)
Co-authored-by: Mohammed Abdulai <nurud43@gmail.com>
Co-authored-by: hieptl <hieptl.developer@gmail.com>
2026-03-02 18:45:29 +07:00
Neha Prasad
f5197bd76a fix: prevent double scrollbar when profile avatar popover is shown (#13115)
Co-authored-by: hieptl <hieptl.developer@gmail.com>
2026-03-02 18:14:04 +07:00
Tim O'Farrell
e1408f7b15 Add timeout to Keycloak operations and convert OfflineTokenStore to async (#13096)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-02 01:48:45 -07:00
Shruti1128
d6b8d80026 Remove unused subscription-related frontend code (#12557) 2026-03-01 21:14:00 +01:00
Hiep Le
1e6a92b454 feat(backend): organizations llm settings api (org project) (#13108) 2026-03-02 00:06:37 +07:00
Hiep Le
b4a3e5db2f feat(backend): saas – organizations app settings api (#13022) 2026-03-01 23:26:39 +07:00
Chris Bagwell
f9d553d0bb Pass container port instead of host port to Docker (#12595)
Co-authored-by: Engel Nyst <engel.nyst@gmail.com>
2026-02-28 17:45:16 +01:00
Tim O'Farrell
f6f6c1ab25 refactor: use SQL filtering and pagination in VerifiedModelStore (#13068)
Co-authored-by: bittoby <brianwhitedev1996@gmail.com>
Co-authored-by: statxc <statxc@user.noreply.github.com>
Co-authored-by: bittoby <bittoby@users.noreply.github.com>
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-28 07:37:11 -07:00
Hiep Le
c511a89426 feat(frontend): display Bitbucket signup disabled message on login page (#13100) 2026-02-28 19:26:16 +07:00
HeyItsChloe
1f82ff04d9 feat(frontend): SaaS NUE profile questions /Onboarding flow (#13029)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: hieptl <hieptl.developer@gmail.com>
2026-02-28 13:27:22 +07:00
HeyItsChloe
eec17311c7 fix(frontend): bitbucket icon color (#13106)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-28 12:12:45 +07:00
177 changed files with 9620 additions and 4280 deletions

View File

@@ -193,14 +193,20 @@ class GithubManager(Manager):
github_view.installation_id
)
# Store the installation token
self.token_manager.store_org_token(
await self.token_manager.store_org_token(
github_view.installation_id, installation_token
)
# Add eyes reaction to acknowledge we've read the request
self._add_reaction(github_view, 'eyes', installation_token)
await self.start_job(github_view)
async def send_message(self, message: Message, github_view: ResolverViewInterface):
async def send_message(self, message: str, github_view: ResolverViewInterface):
"""Send a message to GitHub.
Args:
message: The message content to send (plain text string)
github_view: The GitHub view object containing issue/PR/comment info
"""
installation_token = self.token_manager.load_org_token(
github_view.installation_id
)
@@ -208,14 +214,12 @@ class GithubManager(Manager):
logger.warning('Missing installation token')
return
outgoing_message = message.message
if isinstance(github_view, GithubInlinePRComment):
with Github(auth=Auth.Token(installation_token)) as github_client:
repo = github_client.get_repo(github_view.full_repo_name)
pr = repo.get_pull(github_view.issue_number)
pr.create_review_comment_reply(
comment_id=github_view.comment_id, body=outgoing_message
comment_id=github_view.comment_id, body=message
)
elif (
@@ -226,7 +230,7 @@ class GithubManager(Manager):
with Github(auth=Auth.Token(installation_token)) as github_client:
repo = github_client.get_repo(github_view.full_repo_name)
issue = repo.get_issue(number=github_view.issue_number)
issue.create_comment(outgoing_message)
issue.create_comment(message)
else:
logger.warning('Unsupported location')
@@ -245,7 +249,7 @@ class GithubManager(Manager):
)
try:
msg_info = None
msg_info: str = ''
try:
user_info = github_view.user_info
@@ -361,15 +365,13 @@ class GithubManager(Manager):
msg_info = get_session_expired_message(user_info.username)
msg = self.create_outgoing_message(msg_info)
await self.send_message(msg, github_view)
await self.send_message(msg_info, github_view)
except Exception:
logger.exception('[Github]: Error starting job')
msg = self.create_outgoing_message(
msg='Uh oh! There was an unexpected error starting the job :('
await self.send_message(
'Uh oh! There was an unexpected error starting the job :(', github_view
)
await self.send_message(msg, github_view)
try:
await self.data_collector.save_data(github_view)

View File

@@ -14,7 +14,6 @@ from integrations.solvability.models.summary import SolvabilitySummary
from integrations.utils import ENABLE_SOLVABILITY_ANALYSIS
from pydantic import ValidationError
from server.config import get_config
from storage.database import session_maker
from storage.saas_settings_store import SaasSettingsStore
from openhands.core.config import LLMConfig
@@ -90,7 +89,6 @@ async def summarize_issue_solvability(
# Grab the user's information so we can load their LLM configuration
store = SaasSettingsStore(
user_id=github_view.user_info.keycloak_user_id,
session_maker=session_maker,
config=get_config(),
)

View File

@@ -24,7 +24,6 @@ from jinja2 import Environment
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
from server.auth.token_manager import TokenManager
from server.config import get_config
from storage.database import session_maker
from storage.org_store import OrgStore
from storage.proactive_conversation_store import ProactiveConversationStore
from storage.saas_secrets_store import SaasSecretsStore
@@ -153,9 +152,7 @@ class GithubIssue(ResolverViewInterface):
return user_instructions, conversation_instructions
async def _get_user_secrets(self):
secrets_store = SaasSecretsStore(
self.user_info.keycloak_user_id, session_maker, get_config()
)
secrets_store = SaasSecretsStore(self.user_info.keycloak_user_id, get_config())
user_secrets = await secrets_store.load()
return user_secrets.custom_secrets if user_secrets else None

View File

@@ -121,12 +121,11 @@ class GitlabManager(Manager):
# Check if the user has write access to the repository
return has_write_access
async def send_message(self, message: Message, gitlab_view: ResolverViewInterface):
"""
Send a message to GitLab based on the view type.
async def send_message(self, message: str, gitlab_view: ResolverViewInterface):
"""Send a message to GitLab based on the view type.
Args:
message: The message to send
message: The message content to send (plain text string)
gitlab_view: The GitLab view object containing issue/PR/comment info
"""
keycloak_user_id = gitlab_view.user_info.keycloak_user_id
@@ -138,8 +137,6 @@ class GitlabManager(Manager):
external_auth_id=keycloak_user_id
)
outgoing_message = message.message
if isinstance(gitlab_view, GitlabInlineMRComment) or isinstance(
gitlab_view, GitlabMRComment
):
@@ -147,7 +144,7 @@ class GitlabManager(Manager):
gitlab_view.project_id,
gitlab_view.issue_number,
gitlab_view.discussion_id,
message.message,
message,
)
elif isinstance(gitlab_view, GitlabIssueComment):
@@ -155,14 +152,14 @@ class GitlabManager(Manager):
gitlab_view.project_id,
gitlab_view.issue_number,
gitlab_view.discussion_id,
outgoing_message,
message,
)
elif isinstance(gitlab_view, GitlabIssue):
await gitlab_service.reply_to_issue(
gitlab_view.project_id,
gitlab_view.issue_number,
None, # no discussion id, issue is tagged
outgoing_message,
message,
)
else:
logger.warning(
@@ -262,12 +259,10 @@ class GitlabManager(Manager):
msg_info = get_session_expired_message(user_info.username)
# Send the acknowledgment message
msg = self.create_outgoing_message(msg_info)
await self.send_message(msg, gitlab_view)
await self.send_message(msg_info, gitlab_view)
except Exception as e:
logger.exception(f'[GitLab] Error starting job: {str(e)}')
msg = self.create_outgoing_message(
msg='Uh oh! There was an unexpected error starting the job :('
await self.send_message(
'Uh oh! There was an unexpected error starting the job :(', gitlab_view
)
await self.send_message(msg, gitlab_view)

View File

@@ -6,7 +6,6 @@ from integrations.utils import HOST, get_oh_labels, has_exact_mention
from jinja2 import Environment
from server.auth.token_manager import TokenManager
from server.config import get_config
from storage.database import session_maker
from storage.saas_secrets_store import SaasSecretsStore
from openhands.core.logger import openhands_logger as logger
@@ -78,9 +77,7 @@ class GitlabIssue(ResolverViewInterface):
return user_instructions, conversation_instructions
async def _get_user_secrets(self):
secrets_store = SaasSecretsStore(
self.user_info.keycloak_user_id, session_maker, get_config()
)
secrets_store = SaasSecretsStore(self.user_info.keycloak_user_id, get_config())
user_secrets = await secrets_store.load()
return user_secrets.custom_secrets if user_secrets else None
@@ -449,3 +446,5 @@ class GitlabFactory:
previous_comments=[],
is_mr=True,
)
raise ValueError(f'Unhandled GitLab webhook event: {message}')

View File

@@ -341,17 +341,25 @@ class JiraManager(Manager):
async def send_message(
self,
message: Message,
message: str,
issue_key: str,
jira_cloud_id: str,
svc_acc_email: str,
svc_acc_api_key: str,
):
"""Send a comment to a Jira issue."""
"""Send a comment to a Jira issue.
Args:
message: The message content to send (plain text string)
issue_key: The Jira issue key (e.g., 'PROJ-123')
jira_cloud_id: The Jira Cloud ID
svc_acc_email: Service account email for authentication
svc_acc_api_key: Service account API key for authentication
"""
url = (
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
)
data = {'body': message.message}
data = {'body': message}
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
response = await client.post(
url, auth=(svc_acc_email, svc_acc_api_key), json=data
@@ -366,7 +374,7 @@ class JiraManager(Manager):
view.jira_workspace.svc_acc_api_key
)
await self.send_message(
self.create_outgoing_message(msg=msg),
msg,
issue_key=view.payload.issue_key,
jira_cloud_id=view.jira_workspace.jira_cloud_id,
svc_acc_email=view.jira_workspace.svc_acc_email,
@@ -388,7 +396,7 @@ class JiraManager(Manager):
try:
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
await self.send_message(
self.create_outgoing_message(msg=error_msg),
error_msg,
issue_key=payload.issue_key,
jira_cloud_id=workspace.jira_cloud_id,
svc_acc_email=workspace.svc_acc_email,

View File

@@ -212,8 +212,6 @@ class JiraPayloadParser:
missing.append('issue.id')
if not issue_key:
missing.append('issue.key')
if not user_email:
missing.append('user.emailAddress')
if not display_name:
missing.append('user.displayName')
if not account_id:

View File

@@ -418,7 +418,7 @@ class JiraDcManager(Manager):
jira_dc_view.jira_dc_workspace.svc_acc_api_key
)
await self.send_message(
self.create_outgoing_message(msg=msg_info),
msg_info,
issue_key=jira_dc_view.job_context.issue_key,
base_api_url=jira_dc_view.job_context.base_api_url,
svc_acc_api_key=api_key,
@@ -456,12 +456,19 @@ class JiraDcManager(Manager):
return title, description
async def send_message(
self, message: Message, issue_key: str, base_api_url: str, svc_acc_api_key: str
self, message: str, issue_key: str, base_api_url: str, svc_acc_api_key: str
):
"""Send message/comment to Jira DC issue."""
"""Send message/comment to Jira DC issue.
Args:
message: The message content to send (plain text string)
issue_key: The Jira issue key (e.g., 'PROJ-123')
base_api_url: The base API URL for the Jira DC instance
svc_acc_api_key: Service account API key for authentication
"""
url = f'{base_api_url}/rest/api/2/issue/{issue_key}/comment'
headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
data = {'body': message.message}
data = {'body': message}
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
response = await client.post(url, headers=headers, json=data)
response.raise_for_status()
@@ -481,7 +488,7 @@ class JiraDcManager(Manager):
try:
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
await self.send_message(
self.create_outgoing_message(msg=error_msg),
error_msg,
issue_key=job_context.issue_key,
base_api_url=job_context.base_api_url,
svc_acc_api_key=api_key,
@@ -502,7 +509,7 @@ class JiraDcManager(Manager):
)
await self.send_message(
self.create_outgoing_message(msg=comment_msg),
comment_msg,
issue_key=jira_dc_view.job_context.issue_key,
base_api_url=jira_dc_view.job_context.base_api_url,
svc_acc_api_key=api_key,

View File

@@ -19,7 +19,7 @@ class JiraDcViewInterface(ABC):
conversation_id: str
@abstractmethod
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Get initial instructions for the conversation."""
pass

View File

@@ -36,7 +36,7 @@ class JiraDcNewConversationView(JiraDcViewInterface):
selected_repo: str | None
conversation_id: str
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
instructions_template = jinja_env.get_template('jira_dc_instructions.j2')
@@ -61,7 +61,7 @@ class JiraDcNewConversationView(JiraDcViewInterface):
provider_tokens = await self.saas_user_auth.get_provider_tokens()
user_secrets = await self.saas_user_auth.get_secrets()
instructions, user_msg = self._get_instructions(jinja_env)
instructions, user_msg = await self._get_instructions(jinja_env)
try:
agent_loop_info = await create_new_conversation(
@@ -113,7 +113,7 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
selected_repo: str | None
conversation_id: str
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
user_msg_template = jinja_env.get_template('jira_dc_existing_conversation.j2')
@@ -167,7 +167,7 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
if not agent_state or agent_state == AgentState.LOADING:
raise StartingConvoException('Conversation is still starting')
_, user_msg = self._get_instructions(jinja_env)
_, user_msg = await self._get_instructions(jinja_env)
user_message_event = MessageAction(content=user_msg)
await conversation_manager.send_event_to_conversation(
self.conversation_id, event_to_dict(user_message_event)

View File

@@ -408,7 +408,7 @@ class LinearManager(Manager):
linear_view.linear_workspace.svc_acc_api_key
)
await self.send_message(
self.create_outgoing_message(msg=msg_info),
msg_info,
linear_view.job_context.issue_id,
api_key,
)
@@ -473,8 +473,14 @@ class LinearManager(Manager):
return title, description
async def send_message(self, message: Message, issue_id: str, api_key: str):
"""Send message/comment to Linear issue."""
async def send_message(self, message: str, issue_id: str, api_key: str):
"""Send message/comment to Linear issue.
Args:
message: The message content to send (plain text string)
issue_id: The Linear issue ID to comment on
api_key: The Linear API key for authentication
"""
query = """
mutation CommentCreate($input: CommentCreateInput!) {
commentCreate(input: $input) {
@@ -485,7 +491,7 @@ class LinearManager(Manager):
}
}
"""
variables = {'input': {'issueId': issue_id, 'body': message.message}}
variables = {'input': {'issueId': issue_id, 'body': message}}
return await self._query_api(query, variables, api_key)
async def _send_error_comment(
@@ -498,9 +504,7 @@ class LinearManager(Manager):
try:
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
await self.send_message(
self.create_outgoing_message(msg=error_msg), issue_id, api_key
)
await self.send_message(error_msg, issue_id, api_key)
except Exception as e:
logger.error(f'[Linear] Failed to send error comment: {str(e)}')
@@ -517,7 +521,7 @@ class LinearManager(Manager):
)
await self.send_message(
self.create_outgoing_message(msg=comment_msg),
comment_msg,
linear_view.job_context.issue_id,
api_key,
)

View File

@@ -19,7 +19,7 @@ class LinearViewInterface(ABC):
conversation_id: str
@abstractmethod
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Get initial instructions for the conversation."""
pass

View File

@@ -33,7 +33,7 @@ class LinearNewConversationView(LinearViewInterface):
selected_repo: str | None
conversation_id: str
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
instructions_template = jinja_env.get_template('linear_instructions.j2')
@@ -58,7 +58,7 @@ class LinearNewConversationView(LinearViewInterface):
provider_tokens = await self.saas_user_auth.get_provider_tokens()
user_secrets = await self.saas_user_auth.get_secrets()
instructions, user_msg = self._get_instructions(jinja_env)
instructions, user_msg = await self._get_instructions(jinja_env)
try:
agent_loop_info = await create_new_conversation(
@@ -110,7 +110,7 @@ class LinearExistingConversationView(LinearViewInterface):
selected_repo: str | None
conversation_id: str
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
user_msg_template = jinja_env.get_template('linear_existing_conversation.j2')
@@ -164,7 +164,7 @@ class LinearExistingConversationView(LinearViewInterface):
if not agent_state or agent_state == AgentState.LOADING:
raise StartingConvoException('Conversation is still starting')
_, user_msg = self._get_instructions(jinja_env)
_, user_msg = await self._get_instructions(jinja_env)
user_message_event = MessageAction(content=user_msg)
await conversation_manager.send_event_to_conversation(
self.conversation_id, event_to_dict(user_message_event)

View File

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from integrations.models import Message, SourceType
@@ -12,14 +13,15 @@ class Manager(ABC):
raise NotImplementedError
@abstractmethod
def send_message(self, message: Message):
"Send message to integration from Openhands server"
def send_message(self, message: str, *args: Any, **kwargs: Any):
"""Send message to integration from OpenHands server.
Args:
message: The message content to send (plain text string).
"""
raise NotImplementedError
@abstractmethod
def start_job(self):
"Kick off a job with openhands agent"
raise NotImplementedError
def create_outgoing_message(self, msg: str | dict, ephemeral: bool = False):
return Message(source=SourceType.OPENHANDS, message=msg, ephemeral=ephemeral)

View File

@@ -1,4 +1,5 @@
from enum import Enum
from typing import Any
from pydantic import BaseModel
@@ -16,8 +17,16 @@ class SourceType(str, Enum):
class Message(BaseModel):
"""Message model for incoming webhook payloads from integrations.
Note: This model is intended for INCOMING messages only.
For outgoing messages (e.g., sending comments to GitHub/GitLab),
pass strings directly to the send_message methods instead of
wrapping them in a Message object.
"""
source: SourceType
message: str | dict
message: dict[str, Any]
ephemeral: bool = False

View File

@@ -1,4 +1,5 @@
import re
from typing import Any
import jwt
from integrations.manager import Manager
@@ -22,7 +23,8 @@ from server.constants import SLACK_CLIENT_ID
from server.utils.conversation_callback_utils import register_callback_processor
from slack_sdk.oauth import AuthorizeUrlGenerator
from slack_sdk.web.async_client import AsyncWebClient
from storage.database import session_maker
from sqlalchemy import select
from storage.database import a_session_maker
from storage.slack_user import SlackUser
from openhands.core.logger import openhands_logger as logger
@@ -63,12 +65,11 @@ class SlackManager(Manager):
) -> tuple[SlackUser | None, UserAuth | None]:
# We get the user and correlate them back to a user in OpenHands - if we can
slack_user = None
with session_maker() as session:
slack_user = (
session.query(SlackUser)
.filter(SlackUser.slack_user_id == slack_user_id)
.first()
async with a_session_maker() as session:
result = await session.execute(
select(SlackUser).where(SlackUser.slack_user_id == slack_user_id)
)
slack_user = result.scalar_one_or_none()
# slack_view.slack_to_openhands_user = slack_user # attach user auth info to view
@@ -202,9 +203,7 @@ class SlackManager(Manager):
msg = self.login_link.format(link)
logger.info('slack_not_yet_authenticated')
await self.send_message(
self.create_outgoing_message(msg, ephemeral=True), slack_view
)
await self.send_message(msg, slack_view, ephemeral=True)
return
if not await self.is_job_requested(message, slack_view):
@@ -212,27 +211,40 @@ class SlackManager(Manager):
await self.start_job(slack_view)
async def send_message(self, message: Message, slack_view: SlackViewInterface):
async def send_message(
self,
message: str | dict[str, Any],
slack_view: SlackViewInterface,
ephemeral: bool = False,
):
"""Send a message to Slack.
Args:
message: The message content. Can be a string (for simple text) or
a dict with 'text' and 'blocks' keys (for structured messages).
slack_view: The Slack view object containing channel/thread info.
ephemeral: If True, send as an ephemeral message visible only to the user.
"""
client = AsyncWebClient(token=slack_view.bot_access_token)
if message.ephemeral and isinstance(message.message, str):
if ephemeral and isinstance(message, str):
await client.chat_postEphemeral(
channel=slack_view.channel_id,
markdown_text=message.message,
markdown_text=message,
user=slack_view.slack_user_id,
thread_ts=slack_view.thread_ts,
)
elif message.ephemeral and isinstance(message.message, dict):
elif ephemeral and isinstance(message, dict):
await client.chat_postEphemeral(
channel=slack_view.channel_id,
user=slack_view.slack_user_id,
thread_ts=slack_view.thread_ts,
text=message.message['text'],
blocks=message.message['blocks'],
text=message['text'],
blocks=message['blocks'],
)
else:
await client.chat_postMessage(
channel=slack_view.channel_id,
markdown_text=message.message,
markdown_text=message,
thread_ts=slack_view.message_ts,
)
@@ -279,10 +291,7 @@ class SlackManager(Manager):
repos, slack_view.message_ts, slack_view.thread_ts
),
}
await self.send_message(
self.create_outgoing_message(repo_selection_msg, ephemeral=True),
slack_view,
)
await self.send_message(repo_selection_msg, slack_view, ephemeral=True)
return False
@@ -368,9 +377,10 @@ class SlackManager(Manager):
except StartingConvoException as e:
msg_info = str(e)
await self.send_message(self.create_outgoing_message(msg_info), slack_view)
await self.send_message(msg_info, slack_view)
except Exception:
logger.exception('[Slack]: Error starting job')
msg = 'Uh oh! There was an unexpected error starting the job :('
await self.send_message(self.create_outgoing_message(msg), slack_view)
await self.send_message(
'Uh oh! There was an unexpected error starting the job :(', slack_view
)

View File

@@ -24,7 +24,7 @@ class SlackViewInterface(SummaryExtractionTracker, ABC):
v1_enabled: bool
@abstractmethod
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
pass

View File

@@ -75,7 +75,7 @@ class SlackUnkownUserView(SlackViewInterface):
team_id: str
v1_enabled: bool
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
raise NotImplementedError
async def create_or_update_conversation(self, jinja_env: Environment):
@@ -118,7 +118,7 @@ class SlackNewConversationView(SlackViewInterface):
return block['user_id']
return ''
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
user_info: SlackUser = self.slack_to_openhands_user
@@ -242,7 +242,9 @@ class SlackNewConversationView(SlackViewInterface):
self, jinja: Environment, provider_tokens, user_secrets
) -> None:
"""Create conversation using the legacy V0 system."""
user_instructions, conversation_instructions = self._get_instructions(jinja)
user_instructions, conversation_instructions = await self._get_instructions(
jinja
)
# Determine git provider from repository
git_provider = None
@@ -273,7 +275,9 @@ class SlackNewConversationView(SlackViewInterface):
async def _create_v1_conversation(self, jinja: Environment) -> None:
"""Create conversation using the new V1 app conversation system."""
user_instructions, conversation_instructions = self._get_instructions(jinja)
user_instructions, conversation_instructions = await self._get_instructions(
jinja
)
# Create the initial message request
initial_message = SendMessageRequest(
@@ -346,7 +350,7 @@ class SlackNewConversationFromRepoFormView(SlackNewConversationView):
class SlackUpdateExistingConversationView(SlackNewConversationView):
slack_conversation: SlackConversation
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
client = WebClient(token=self.bot_access_token)
result = client.conversations_replies(
channel=self.channel_id,
@@ -401,7 +405,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
if not agent_state or agent_state == AgentState.LOADING:
raise StartingConvoException('Conversation is still starting')
instructions, _ = self._get_instructions(jinja)
instructions, _ = await self._get_instructions(jinja)
user_msg = MessageAction(content=instructions)
await conversation_manager.send_event_to_conversation(
self.conversation_id, event_to_dict(user_msg)
@@ -469,7 +473,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
agent_server_url = get_agent_server_url_from_sandbox(running_sandbox)
# 4. Prepare the message content
user_msg, _ = self._get_instructions(jinja)
user_msg, _ = await self._get_instructions(jinja)
# 5. Create the message request
send_message_request = SendMessageRequest(

View File

@@ -42,11 +42,11 @@ async def store_repositories_in_db(repos: list[Repository], user_id: str) -> Non
try:
# Store repositories in the repos table
repo_store = RepositoryStore.get_instance(config)
repo_store.store_projects(stored_repos)
await repo_store.store_projects(stored_repos)
# Store user-repository mappings in the user-repos table
user_repo_store = UserRepositoryMapStore.get_instance(config)
user_repo_store.store_user_repo_mappings(user_repos)
await user_repo_store.store_user_repo_mappings(user_repos)
logger.info(f'Saved repos for user {user_id}')
except Exception:

View File

@@ -3,8 +3,8 @@ from uuid import UUID
import stripe
from server.constants import STRIPE_API_KEY
from server.logger import logger
from sqlalchemy.orm import Session
from storage.database import session_maker
from sqlalchemy import select
from storage.database import a_session_maker
from storage.org import Org
from storage.org_store import OrgStore
from storage.stripe_customer import StripeCustomer
@@ -15,12 +15,10 @@ stripe.api_key = STRIPE_API_KEY
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
with session_maker() as session:
stripe_customer = (
session.query(StripeCustomer)
.filter(StripeCustomer.org_id == org_id)
.first()
)
async with a_session_maker() as session:
stmt = select(StripeCustomer).where(StripeCustomer.org_id == org_id)
result = await session.execute(stmt)
stripe_customer = result.scalar_one_or_none()
if stripe_customer:
return stripe_customer.stripe_customer_id
@@ -74,7 +72,7 @@ async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
)
# Save the stripe customer in the local db
with session_maker() as session:
async with a_session_maker() as session:
session.add(
StripeCustomer(
keycloak_user_id=user_id,
@@ -82,7 +80,7 @@ async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
stripe_customer_id=customer.id,
)
)
session.commit()
await session.commit()
logger.info(
'created_customer',
@@ -108,26 +106,27 @@ async def has_payment_method_by_user_id(user_id: str) -> bool:
return bool(payment_methods.data)
async def migrate_customer(session: Session, user_id: str, org: Org):
stripe_customer = (
session.query(StripeCustomer)
.filter(StripeCustomer.keycloak_user_id == user_id)
.first()
)
if stripe_customer is None:
return
stripe_customer.org_id = org.id
customer = await stripe.Customer.modify_async(
id=stripe_customer.stripe_customer_id,
email=org.contact_email,
metadata={'user_id': '', 'org_id': str(org.id)},
)
async def migrate_customer(user_id: str, org: Org):
async with a_session_maker() as session:
result = await session.execute(
select(StripeCustomer).where(StripeCustomer.keycloak_user_id == user_id)
)
stripe_customer = result.scalar_one_or_none()
if stripe_customer is None:
return
stripe_customer.org_id = org.id
customer = await stripe.Customer.modify_async(
id=stripe_customer.stripe_customer_id,
email=org.contact_email,
metadata={'user_id': '', 'org_id': str(org.id)},
)
logger.info(
'migrated_customer',
extra={
'user_id': user_id,
'org_id': str(org.id),
'stripe_customer_id': customer.id,
},
)
logger.info(
'migrated_customer',
extra={
'user_id': user_id,
'org_id': str(org.id),
'stripe_customer_id': customer.id,
},
)
await session.commit()

View File

@@ -38,7 +38,7 @@ class ResolverViewInterface(SummaryExtractionTracker):
is_public_repo: bool
raw_payload: dict
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"Instructions passed when conversation is first initialized"
raise NotImplementedError()

17
enterprise/poetry.lock generated
View File

@@ -1591,6 +1591,9 @@ files = [
{file = "cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c"},
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a"},
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356"},
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da"},
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257"},
{file = "cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7"},
{file = "cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d"},
]
@@ -6168,23 +6171,23 @@ opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
pathspec = ">=0.12.1"
pexpect = "*"
pg8000 = ">=1.31.5"
pillow = ">=11.3"
pillow = ">=12.1.1"
playwright = ">=1.55"
poetry = ">=2.1.2"
prompt-toolkit = ">=3.0.50"
protobuf = ">=5,<6"
protobuf = ">=5.29.6,<6"
psutil = "*"
pybase62 = ">=1"
pygithub = ">=2.5"
pyjwt = ">=2.9"
pylatexenc = "*"
pypdf = ">=6"
pypdf = ">=6.7.2"
python-docx = "*"
python-dotenv = "*"
python-frontmatter = ">=1.1"
python-jose = {version = ">=3.3", extras = ["cryptography"]}
python-json-logger = ">=3.2.1"
python-multipart = "*"
python-multipart = ">=0.0.22"
python-pptx = "*"
python-socketio = "5.14"
pythonnet = "*"
@@ -6197,7 +6200,7 @@ setuptools = ">=78.1.1"
shellingham = ">=1.5.4"
sqlalchemy = {version = ">=2.0.40", extras = ["asyncio"]}
sse-starlette = ">=3.0.2"
starlette = ">=0.48"
starlette = ">=0.49.1"
tenacity = ">=8.5,<10"
termcolor = "*"
toml = "*"
@@ -11961,7 +11964,7 @@ description = "Python for Window Extensions"
optional = false
python-versions = "*"
groups = ["main"]
markers = "platform_system == \"Windows\" or sys_platform == \"win32\""
markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
files = [
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
@@ -14909,4 +14912,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = "^3.12,<3.14"
content-hash = "1cad6029269393af67155e930c72eae2c03da02e4b3a3699823f6168c14a4218"
content-hash = "ef037f6d6085d26166d35c56ce266439f8f1a4fea90bc43ccf15cfeaf116cae5"

View File

@@ -49,7 +49,7 @@ prometheus-client = "^0.24.0"
pandas = "^2.2.0"
numpy = "^2.2.0"
mcp = "^1.10.0"
pillow = "^12.1.0"
pillow = "^12.1.1"
[tool.poetry.group.dev.dependencies]
ruff = "0.8.3"

View File

@@ -48,15 +48,18 @@ from server.routes.orgs import org_router # noqa: E402
from server.routes.readiness import readiness_router # noqa: E402
from server.routes.user import saas_user_router # noqa: E402
from server.routes.user_app_settings import user_app_settings_router # noqa: E402
from server.routes.verified_models import ( # noqa: E402
api_router as verified_models_router,
)
from server.sharing.shared_conversation_router import ( # noqa: E402
router as shared_conversation_router,
)
from server.sharing.shared_event_router import ( # noqa: E402
router as shared_event_router,
)
from server.verified_models.verified_model_router import ( # noqa: E402
api_router as verified_models_router,
)
from server.verified_models.verified_model_router import ( # noqa: E402
override_llm_models_dependency,
)
from openhands.server.app import app as base_app # noqa: E402
from openhands.server.listen_socket import sio # noqa: E402
@@ -113,6 +116,11 @@ base_app.include_router(org_router) # Add routes for organization management
base_app.include_router(
verified_models_router
) # Add routes for verified models management
# Override the default LLM models implementation with SaaS version
# This must happen after all routers are included
override_llm_models_dependency(base_app)
base_app.include_router(invitation_router) # Add routes for org invitation management
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
add_github_proxy_routes(base_app)

View File

@@ -1,5 +1,4 @@
from storage.blocked_email_domain_store import BlockedEmailDomainStore
from storage.database import session_maker
from openhands.core.logger import openhands_logger as logger
@@ -23,7 +22,7 @@ class DomainBlocker:
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
return None
def is_domain_blocked(self, email: str) -> bool:
async def is_domain_blocked(self, email: str) -> bool:
"""Check if email domain is blocked by querying the database directly via SQL.
Supports blocking:
@@ -45,7 +44,7 @@ class DomainBlocker:
try:
# Query database directly via SQL to check if domain is blocked
is_blocked = self.store.is_domain_blocked(domain)
is_blocked = await self.store.is_domain_blocked(domain)
if is_blocked:
logger.warning(f'Email domain {domain} is blocked for email: {email}')
@@ -63,5 +62,5 @@ class DomainBlocker:
# Initialize store and domain blocker
_store = BlockedEmailDomainStore(session_maker=session_maker)
_store = BlockedEmailDomainStore()
domain_blocker = DomainBlocker(store=_store)

View File

@@ -1,7 +1,7 @@
from integrations.github.github_service import SaaSGitHubService
from pydantic import SecretStr
from server.auth.auth_utils import user_verifier
from enterprise.server.auth.auth_utils import user_verifier
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_types import GitHubUser

View File

@@ -18,9 +18,10 @@ from server.auth.token_manager import TokenManager
from server.config import get_config
from server.logger import logger
from server.rate_limit import RateLimiter, create_redis_rate_limiter
from sqlalchemy import delete, select
from storage.api_key_store import ApiKeyStore
from storage.auth_tokens import AuthTokens
from storage.database import session_maker
from storage.database import a_session_maker
from storage.saas_secrets_store import SaasSecretsStore
from storage.saas_settings_store import SaasSettingsStore
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
@@ -124,7 +125,7 @@ class SaasUserAuth(UserAuth):
if secrets_store:
return secrets_store
user_id = await self.get_user_id()
secrets_store = SaasSecretsStore(user_id, session_maker, get_config())
secrets_store = SaasSecretsStore(user_id, get_config())
self.secrets_store = secrets_store
return secrets_store
@@ -161,12 +162,13 @@ class SaasUserAuth(UserAuth):
try:
# TODO: I think we can do this in a single request if we refactor
with session_maker() as session:
tokens = (
session.query(AuthTokens)
.where(AuthTokens.keycloak_user_id == self.user_id)
.all()
async with a_session_maker() as session:
result = await session.execute(
select(AuthTokens).where(
AuthTokens.keycloak_user_id == self.user_id
)
)
tokens = result.scalars().all()
for token in tokens:
idp_type = ProviderType(token.identity_provider)
@@ -192,11 +194,11 @@ class SaasUserAuth(UserAuth):
'idp_type': token.identity_provider,
},
)
with session_maker() as session:
session.query(AuthTokens).filter(
AuthTokens.id == token.id
).delete()
session.commit()
async with a_session_maker() as session:
await session.execute(
delete(AuthTokens).where(AuthTokens.id == token.id)
)
await session.commit()
raise
self.provider_tokens = MappingProxyType(provider_tokens)
@@ -210,7 +212,7 @@ class SaasUserAuth(UserAuth):
if settings_store:
return settings_store
user_id = await self.get_user_id()
settings_store = SaasSettingsStore(user_id, session_maker, get_config())
settings_store = SaasSettingsStore(user_id, get_config())
self.settings_store = settings_store
return settings_store
@@ -278,7 +280,7 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
return None
api_key_store = ApiKeyStore.get_instance()
user_id = api_key_store.validate_api_key(api_key)
user_id = await api_key_store.validate_api_key(api_key)
if not user_id:
return None
offline_token = await token_manager.load_offline_token(user_id)
@@ -327,7 +329,7 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
email_verified = access_token_payload['email_verified']
# Check if email domain is blocked
if email and domain_blocker.is_domain_blocked(email):
if email and await domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for existing user with email: {email}'
)

View File

@@ -38,9 +38,9 @@ from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid
from server.config import get_config
from server.logger import logger
from sqlalchemy import String as SQLString
from sqlalchemy import type_coerce
from sqlalchemy import select, type_coerce
from storage.auth_token_store import AuthTokenStore
from storage.database import session_maker
from storage.database import a_session_maker
from storage.github_app_installation import GithubAppInstallation
from storage.offline_token_store import OfflineTokenStore
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt
@@ -783,25 +783,24 @@ class TokenManager:
exc_info=True,
)
def store_org_token(self, installation_id: int, installation_token: str):
async def store_org_token(self, installation_id: int, installation_token: str):
"""Store a GitHub App installation token.
Args:
installation_id: GitHub installation ID (integer or string)
installation_token: The token to store
"""
with session_maker() as session:
async with a_session_maker() as session:
# Ensure installation_id is a string
str_installation_id = str(installation_id)
# Use type_coerce to ensure SQLAlchemy treats the parameter as a string
installation = (
session.query(GithubAppInstallation)
.filter(
result = await session.execute(
select(GithubAppInstallation).filter(
GithubAppInstallation.installation_id
== type_coerce(str_installation_id, SQLString)
)
.first()
)
installation = result.scalars().first()
if installation:
installation.encrypted_token = self.encrypt_text(installation_token)
else:
@@ -811,9 +810,9 @@ class TokenManager:
encrypted_token=self.encrypt_text(installation_token),
)
)
session.commit()
await session.commit()
def load_org_token(self, installation_id: int) -> str | None:
async def load_org_token(self, installation_id: int) -> str | None:
"""Load a GitHub App installation token.
Args:
@@ -822,17 +821,16 @@ class TokenManager:
Returns:
The decrypted token if found, None otherwise
"""
with session_maker() as session:
async with a_session_maker() as session:
# Ensure installation_id is a string and use type_coerce
str_installation_id = str(installation_id)
installation = (
session.query(GithubAppInstallation)
.filter(
result = await session.execute(
select(GithubAppInstallation).filter(
GithubAppInstallation.installation_id
== type_coerce(str_installation_id, SQLString)
)
.first()
)
installation = result.scalars().first()
if not installation:
return None
token = self.decrypt_text(installation.encrypted_token)

View File

@@ -3,7 +3,6 @@ from datetime import datetime
from integrations.github.github_manager import GithubManager
from integrations.github.github_view import GithubViewType
from integrations.models import Message, SourceType
from integrations.utils import (
extract_summary_from_conversation_manager,
get_summary_instruction,
@@ -35,16 +34,12 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
send_summary_instruction: bool = True
async def _send_message_to_github(self, message: str) -> None:
"""
Send a message to GitHub.
"""Send a message to GitHub.
Args:
message: The message content to send to GitHub
"""
try:
# Create a message object for GitHub
message_obj = Message(source=SourceType.OPENHANDS, message=message)
# Get the token manager
token_manager = TokenManager()
@@ -53,8 +48,8 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
github_manager = GithubManager(token_manager, GitHubDataCollector())
# Send the message
await github_manager.send_message(message_obj, self.github_view)
# Send the message directly as a string
await github_manager.send_message(message, self.github_view)
logger.info(
f'[GitHub] Sent summary message to {self.github_view.full_repo_name}#{self.github_view.issue_number}'

View File

@@ -3,7 +3,6 @@ from datetime import datetime
from integrations.gitlab.gitlab_manager import GitlabManager
from integrations.gitlab.gitlab_view import GitlabViewType
from integrations.models import Message, SourceType
from integrations.utils import (
extract_summary_from_conversation_manager,
get_summary_instruction,
@@ -14,7 +13,7 @@ from storage.conversation_callback import (
ConversationCallback,
ConversationCallbackProcessor,
)
from storage.database import session_maker
from storage.database import a_session_maker
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
@@ -28,8 +27,7 @@ gitlab_manager = GitlabManager(token_manager)
class GitlabCallbackProcessor(ConversationCallbackProcessor):
"""
Processor for sending conversation summaries to GitLab.
"""Processor for sending conversation summaries to GitLab.
This processor is used to send summaries of conversations to GitLab
when agent state changes occur.
@@ -39,22 +37,18 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
send_summary_instruction: bool = True
async def _send_message_to_gitlab(self, message: str) -> None:
"""
Send a message to GitLab.
"""Send a message to GitLab.
Args:
message: The message content to send to GitLab
"""
try:
# Create a message object for GitHub
message_obj = Message(source=SourceType.OPENHANDS, message=message)
# Get the token manager
token_manager = TokenManager()
gitlab_manager = GitlabManager(token_manager)
# Send the message
await gitlab_manager.send_message(message_obj, self.gitlab_view)
# Send the message directly as a string
await gitlab_manager.send_message(message, self.gitlab_view)
logger.info(
f'[GitLab] Sent summary message to {self.gitlab_view.full_repo_name}#{self.gitlab_view.issue_number}'
@@ -111,9 +105,9 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
self.send_summary_instruction = False
callback.set_processor(self)
callback.updated_at = datetime.now()
with session_maker() as session:
async with a_session_maker() as session:
session.merge(callback)
session.commit()
await session.commit()
return
# Extract the summary from the event store
@@ -132,9 +126,9 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
# Mark callback as completed status
callback.status = CallbackStatus.COMPLETED
callback.updated_at = datetime.now()
with session_maker() as session:
async with a_session_maker() as session:
session.merge(callback)
session.commit()
await session.commit()
except Exception as e:
logger.exception(

View File

@@ -37,8 +37,7 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
workspace_name: str
async def _send_comment_to_jira(self, message: str) -> None:
"""
Send a comment to Jira issue.
"""Send a comment to Jira issue.
Args:
message: The message content to send to Jira
@@ -59,8 +58,9 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
# Decrypt API key
api_key = jira_manager.token_manager.decrypt_text(workspace.svc_acc_api_key)
# Send comment directly as a string
await jira_manager.send_message(
jira_manager.create_outgoing_message(msg=message),
message,
issue_key=self.issue_key,
jira_cloud_id=workspace.jira_cloud_id,
svc_acc_email=workspace.svc_acc_email,

View File

@@ -37,8 +37,7 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
base_api_url: str
async def _send_comment_to_jira_dc(self, message: str) -> None:
"""
Send a comment to Jira DC issue.
"""Send a comment to Jira DC issue.
Args:
message: The message content to send to Jira DC
@@ -61,8 +60,9 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
workspace.svc_acc_api_key
)
# Send comment directly as a string
await jira_dc_manager.send_message(
jira_dc_manager.create_outgoing_message(msg=message),
message,
issue_key=self.issue_key,
base_api_url=self.base_api_url,
svc_acc_api_key=api_key,

View File

@@ -36,8 +36,7 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
workspace_name: str
async def _send_comment_to_linear(self, message: str) -> None:
"""
Send a comment to Linear issue.
"""Send a comment to Linear issue.
Args:
message: The message content to send to Linear
@@ -60,9 +59,9 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
workspace.svc_acc_api_key
)
# Send comment
# Send comment directly as a string
await linear_manager.send_message(
linear_manager.create_outgoing_message(msg=message),
message,
self.issue_id,
api_key,
)

View File

@@ -26,8 +26,7 @@ slack_manager = SlackManager(token_manager)
class SlackCallbackProcessor(ConversationCallbackProcessor):
"""
Processor for sending conversation summaries to Slack.
"""Processor for sending conversation summaries to Slack.
This processor is used to send summaries of conversations to Slack channels
when agent state changes occur.
@@ -41,14 +40,13 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
last_user_msg_id: int | None = None
async def _send_message_to_slack(self, message: str) -> None:
"""
Send a message to Slack using the conversation_manager's send_to_event_stream method.
"""Send a message to Slack.
Args:
message: The message content to send to Slack
"""
try:
# Create a message object for Slack
# Create a message object for Slack view creation (incoming message format)
message_obj = Message(
source=SourceType.SLACK,
message={
@@ -67,9 +65,8 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
slack_view = SlackFactory.create_slack_view_from_payload(
message_obj, slack_user, saas_user_auth
)
await slack_manager.send_message(
slack_manager.create_outgoing_message(message), slack_view
)
# Send the message directly as a string
await slack_manager.send_message(message, slack_view)
logger.info(
f'[Slack] Sent summary message to channel {self.channel_id} '

View File

@@ -251,7 +251,7 @@ async def delete_api_key(
)
# Delete the key
success = api_key_store.delete_api_key_by_id(key_id)
success = await api_key_store.delete_api_key_by_id(key_id)
if not success:
raise HTTPException(

View File

@@ -34,7 +34,8 @@ from server.services.org_invitation_service import (
OrgInvitationService,
UserAlreadyMemberError,
)
from storage.database import session_maker
from sqlalchemy import select
from storage.database import a_session_maker
from storage.user import User
from storage.user_store import UserStore
@@ -270,7 +271,7 @@ async def keycloak_callback(
# Fail open - continue with login if reCAPTCHA service unavailable
# Check if email domain is blocked
if email and domain_blocker.is_domain_blocked(email):
if email and await domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
)
@@ -610,17 +611,20 @@ async def accept_tos(request: Request):
# Update user settings with TOS acceptance
accepted_tos: datetime = datetime.now(timezone.utc)
with session_maker() as session:
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
async with a_session_maker() as session:
result = await session.execute(
select(User).where(User.id == uuid.UUID(user_id))
)
user = result.scalar_one_or_none()
if not user:
session.rollback()
await session.rollback()
logger.error('User for {user_id} not found.')
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'User does not exist'},
)
user.accepted_tos = accepted_tos
session.commit()
await session.commit()
logger.info(f'User {user_id} accepted TOS')

View File

@@ -11,9 +11,10 @@ from integrations import stripe_service
from pydantic import BaseModel
from server.constants import STRIPE_API_KEY
from server.logger import logger
from sqlalchemy import select
from starlette.datastructures import URL
from storage.billing_session import BillingSession
from storage.database import session_maker
from storage.database import a_session_maker
from storage.lite_llm_manager import LiteLlmManager
from storage.org import Org
from storage.subscription_access import SubscriptionAccess
@@ -106,16 +107,17 @@ async def get_subscription_access(
user_id: str = Depends(get_user_id),
) -> SubscriptionAccessResponse | None:
"""Get details of the currently valid subscription for the user."""
with session_maker() as session:
async with a_session_maker() as session:
now = datetime.now(UTC)
subscription_access = (
session.query(SubscriptionAccess)
.filter(SubscriptionAccess.status == 'ACTIVE')
.filter(SubscriptionAccess.user_id == user_id)
.filter(SubscriptionAccess.start_at <= now)
.filter(SubscriptionAccess.end_at >= now)
.first()
result = await session.execute(
select(SubscriptionAccess).where(
SubscriptionAccess.status == 'ACTIVE',
SubscriptionAccess.user_id == user_id,
SubscriptionAccess.start_at <= now,
SubscriptionAccess.end_at >= now,
)
)
subscription_access = result.scalar_one_or_none()
if not subscription_access:
return None
return SubscriptionAccessResponse(
@@ -197,7 +199,7 @@ async def create_checkout_session(
'checkout_session_id': checkout_session.id,
},
)
with session_maker() as session:
async with a_session_maker() as session:
billing_session = BillingSession(
id=checkout_session.id,
user_id=user_id,
@@ -206,7 +208,7 @@ async def create_checkout_session(
price_code='NA',
)
session.add(billing_session)
session.commit()
await session.commit()
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
@@ -215,13 +217,14 @@ async def create_checkout_session(
@billing_router.get('/success')
async def success_callback(session_id: str, request: Request):
# We can't use the auth cookie because of SameSite=strict
with session_maker() as session:
billing_session = (
session.query(BillingSession)
.filter(BillingSession.id == session_id)
.filter(BillingSession.status == 'in_progress')
.first()
async with a_session_maker() as session:
result = await session.execute(
select(BillingSession).where(
BillingSession.id == session_id,
BillingSession.status == 'in_progress',
)
)
billing_session = result.scalar_one_or_none()
if billing_session is None:
# Hopefully this never happens - we get a redirect from stripe where the session does not exist
@@ -253,7 +256,8 @@ async def success_callback(session_id: str, request: Request):
user_team_info, billing_session.user_id, str(user.current_org_id)
)
org = session.query(Org).filter(Org.id == user.current_org_id).first()
result = await session.execute(select(Org).where(Org.id == user.current_org_id))
org = result.scalar_one_or_none()
new_max_budget = max_budget + add_credits
await LiteLlmManager.update_team_and_users_budget(
@@ -279,7 +283,7 @@ async def success_callback(session_id: str, request: Request):
'stripe_customer_id': stripe_session.customer,
},
)
session.commit()
await session.commit()
return RedirectResponse(
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
@@ -289,13 +293,14 @@ async def success_callback(session_id: str, request: Request):
# Callback endpoint for cancelled Stripe payments - updates billing session status
@billing_router.get('/cancel')
async def cancel_callback(session_id: str, request: Request):
with session_maker() as session:
billing_session = (
session.query(BillingSession)
.filter(BillingSession.id == session_id)
.filter(BillingSession.status == 'in_progress')
.first()
async with a_session_maker() as session:
result = await session.execute(
select(BillingSession).where(
BillingSession.id == session_id,
BillingSession.status == 'in_progress',
)
)
billing_session = result.scalar_one_or_none()
if billing_session:
logger.info(
'stripe_checkout_cancel',
@@ -307,7 +312,7 @@ async def cancel_callback(session_id: str, request: Request):
billing_session.status = 'cancelled'
billing_session.updated_at = datetime.now(UTC)
session.merge(billing_session)
session.commit()
await session.commit()
return RedirectResponse(
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy.future import select
from storage.database import session_maker
from storage.database import a_session_maker
from storage.feedback import ConversationFeedback
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
@@ -11,7 +11,6 @@ from openhands.events.event_store import EventStore
from openhands.server.dependencies import get_dependencies
from openhands.server.shared import file_store
from openhands.server.user_auth import get_user_id
from openhands.utils.async_utils import call_sync_from_async
# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint
# is protected. The actual protection is provided by SetAuthCookieMiddleware
@@ -37,23 +36,19 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
"""
# Verify the conversation belongs to the user
def _verify_conversation():
with session_maker() as session:
metadata = (
session.query(StoredConversationMetadataSaas)
.filter(
StoredConversationMetadataSaas.conversation_id == conversation_id,
StoredConversationMetadataSaas.user_id == user_id,
)
.first()
async with a_session_maker() as session:
result = await session.execute(
select(StoredConversationMetadataSaas).where(
StoredConversationMetadataSaas.conversation_id == conversation_id,
StoredConversationMetadataSaas.user_id == user_id,
)
)
metadata = result.scalars().first()
if not metadata:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f'Conversation {conversation_id} not found',
)
if not metadata:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f'Conversation {conversation_id} not found',
)
await call_sync_from_async(_verify_conversation)
# Create an event store to access the events directly
# This works even when the conversation is not running
@@ -103,12 +98,9 @@ async def submit_conversation_feedback(feedback: FeedbackRequest):
)
# Add to database
def _save_feedback():
with session_maker() as session:
session.add(new_feedback)
session.commit()
await call_sync_from_async(_save_feedback)
async with a_session_maker() as session:
session.add(new_feedback)
await session.commit()
return {'status': 'success', 'message': 'Feedback submitted successfully'}
@@ -127,30 +119,27 @@ async def get_batch_feedback(conversation_id: str, user_id: str = Depends(get_us
return {}
# Query for existing feedback for all events
def _check_feedback():
with session_maker() as session:
result = session.execute(
select(ConversationFeedback).where(
ConversationFeedback.conversation_id == conversation_id,
ConversationFeedback.event_id.in_(event_ids),
)
async with a_session_maker() as session:
result = await session.execute(
select(ConversationFeedback).where(
ConversationFeedback.conversation_id == conversation_id,
ConversationFeedback.event_id.in_(event_ids),
)
)
# Create a mapping of event_id to feedback
feedback_map = {
feedback.event_id: {
'exists': True,
'rating': feedback.rating,
'reason': feedback.reason,
}
for feedback in result.scalars()
# Create a mapping of event_id to feedback
feedback_map = {
feedback.event_id: {
'exists': True,
'rating': feedback.rating,
'reason': feedback.reason,
}
for feedback in result.scalars()
}
# Build response including all events
response = {}
for event_id in event_ids:
response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
# Build response including all events
response = {}
for event_id in event_ids:
response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
return response
return await call_sync_from_async(_check_feedback)
return response

View File

@@ -308,10 +308,11 @@ async def jira_events(
logger.info(f'Processing new Jira webhook event: {signature}')
redis_client.setex(key, 300, '1')
# Process the webhook
# Process the webhook in background after returning response.
# Note: For async functions, BackgroundTasks runs them in the same event loop
# (not a thread pool), so asyncpg connections work correctly.
message_payload = {'payload': payload}
message = Message(source=SourceType.JIRA, message=message_payload)
background_tasks.add_task(jira_manager.receive_message, message)
return JSONResponse({'success': True})

View File

@@ -7,7 +7,6 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from storage.api_key_store import ApiKeyStore
from storage.database import session_maker
from storage.device_code_store import DeviceCodeStore
from openhands.core.logger import openhands_logger as logger
@@ -54,7 +53,7 @@ class DeviceTokenErrorResponse(BaseModel):
# ---------------------------------------------------------------------------
oauth_device_router = APIRouter(prefix='/oauth/device')
device_code_store = DeviceCodeStore(session_maker)
device_code_store = DeviceCodeStore()
# ---------------------------------------------------------------------------
@@ -90,7 +89,7 @@ async def device_authorization(
) -> DeviceAuthorizationResponse:
"""Start device flow by generating device and user codes."""
try:
device_code_entry = device_code_store.create_device_code(
device_code_entry = await device_code_store.create_device_code(
expires_in=DEVICE_CODE_EXPIRES_IN,
)
@@ -125,7 +124,7 @@ async def device_authorization(
async def device_token(device_code: str = Form(...)):
"""Poll for a token until the user authorizes or the code expires."""
try:
device_code_entry = device_code_store.get_by_device_code(device_code)
device_code_entry = await device_code_store.get_by_device_code(device_code)
if not device_code_entry:
return _oauth_error(
@@ -138,7 +137,9 @@ async def device_token(device_code: str = Form(...)):
is_too_fast, current_interval = device_code_entry.check_rate_limit()
if is_too_fast:
# Update poll time and increase interval
device_code_store.update_poll_time(device_code, increase_interval=True)
await device_code_store.update_poll_time(
device_code, increase_interval=True
)
logger.warning(
'Client polling too fast, returning slow_down error',
extra={
@@ -154,7 +155,7 @@ async def device_token(device_code: str = Form(...)):
)
# Update poll time for successful rate limit check
device_code_store.update_poll_time(device_code, increase_interval=False)
await device_code_store.update_poll_time(device_code, increase_interval=False)
if device_code_entry.is_expired():
return _oauth_error(
@@ -181,7 +182,7 @@ async def device_token(device_code: str = Form(...)):
# Retrieve the specific API key for this device using the user_code
api_key_store = ApiKeyStore.get_instance()
device_key_name = f'{API_KEY_NAME} ({device_code_entry.user_code})'
device_api_key = api_key_store.retrieve_api_key_by_name(
device_api_key = await api_key_store.retrieve_api_key_by_name(
device_code_entry.keycloak_user_id, device_key_name
)
@@ -238,7 +239,7 @@ async def device_verification_authenticated(
)
# Validate device code
device_code_entry = device_code_store.get_by_user_code(user_code)
device_code_entry = await device_code_store.get_by_user_code(user_code)
if not device_code_entry:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -252,7 +253,7 @@ async def device_verification_authenticated(
)
# First, authorize the device code
success = device_code_store.authorize_device_code(
success = await device_code_store.authorize_device_code(
user_code=user_code,
user_id=user_id,
)
@@ -289,7 +290,7 @@ async def device_verification_authenticated(
# Clean up: revert the device authorization since API key creation failed
# This prevents the device from being in an authorized state without an API key
try:
device_code_store.deny_device_code(user_code)
await device_code_store.deny_device_code(user_code)
logger.info(
'Reverted device authorization due to API key creation failure',
extra={'user_code': user_code, 'user_id': user_id},

View File

@@ -1,6 +1,13 @@
from typing import Annotated
from pydantic import BaseModel, EmailStr, Field, SecretStr, StringConstraints
from pydantic import (
BaseModel,
EmailStr,
Field,
SecretStr,
StringConstraints,
field_validator,
)
from storage.org import Org
from storage.org_member import OrgMember
from storage.role import Role
@@ -252,6 +259,115 @@ class OrgUpdate(BaseModel):
condenser_max_size: int | None = Field(default=None, ge=20)
class OrgLLMSettingsResponse(BaseModel):
"""Response model for organization LLM settings."""
default_llm_model: str | None = None
default_llm_base_url: str | None = None
search_api_key: str | None = None # Masked in response
agent: str | None = None
confirmation_mode: bool | None = None
security_analyzer: str | None = None
enable_default_condenser: bool = True
condenser_max_size: int | None = None
default_max_iterations: int | None = None
@staticmethod
def _mask_key(secret: SecretStr | None) -> str | None:
"""Mask an API key, showing only last 4 characters."""
if secret is None:
return None
raw = secret.get_secret_value()
if not raw:
return None
if len(raw) <= 4:
return '****'
return '****' + raw[-4:]
@classmethod
def from_org(cls, org: Org) -> 'OrgLLMSettingsResponse':
"""Create response from Org entity."""
return cls(
default_llm_model=org.default_llm_model,
default_llm_base_url=org.default_llm_base_url,
search_api_key=cls._mask_key(org.search_api_key),
agent=org.agent,
confirmation_mode=org.confirmation_mode,
security_analyzer=org.security_analyzer,
enable_default_condenser=org.enable_default_condenser
if org.enable_default_condenser is not None
else True,
condenser_max_size=org.condenser_max_size,
default_max_iterations=org.default_max_iterations,
)
class OrgMemberLLMSettings(BaseModel):
"""LLM settings to propagate to organization members.
Field names match OrgMember DB columns.
"""
llm_model: str | None = None
llm_base_url: str | None = None
max_iterations: int | None = None
llm_api_key: str | None = None
def has_updates(self) -> bool:
"""Check if any field is set (not None)."""
return any(getattr(self, field) is not None for field in self.model_fields)
class OrgLLMSettingsUpdate(BaseModel):
"""Request model for updating organization LLM settings.
Field names match Org DB columns exactly.
"""
default_llm_model: str | None = None
default_llm_base_url: str | None = None
search_api_key: str | None = None
agent: str | None = None
confirmation_mode: bool | None = None
security_analyzer: str | None = None
enable_default_condenser: bool | None = None
condenser_max_size: int | None = Field(default=None, ge=20)
default_max_iterations: int | None = Field(default=None, gt=0)
llm_api_key: str | None = None
def has_updates(self) -> bool:
"""Check if any field is set (not None)."""
return any(getattr(self, field) is not None for field in self.model_fields)
def apply_to_org(self, org: Org) -> None:
"""Apply non-None settings to the organization model.
Args:
org: Organization entity to update in place
"""
for field_name in self.model_fields:
value = getattr(self, field_name)
# Skip llm_api_key - it's only for member propagation, not org-level
if value is not None and field_name != 'llm_api_key':
setattr(org, field_name, value)
def get_member_updates(self) -> OrgMemberLLMSettings | None:
"""Get updates that need to be propagated to org members.
Returns:
OrgMemberLLMSettings with mapped field values, or None if no member updates needed.
Maps: default_llm_model → llm_model, default_llm_base_url → llm_base_url,
default_max_iterations → max_iterations, llm_api_key → llm_api_key
"""
member_settings = OrgMemberLLMSettings(
llm_model=self.default_llm_model,
llm_base_url=self.default_llm_base_url,
max_iterations=self.default_max_iterations,
llm_api_key=self.llm_api_key,
)
return member_settings if member_settings.has_updates() else None
class OrgMemberResponse(BaseModel):
"""Response model for a single organization member."""
@@ -327,3 +443,44 @@ class MeResponse(BaseModel):
llm_base_url=member.llm_base_url,
status=member.status,
)
class OrgAppSettingsResponse(BaseModel):
"""Response model for organization app settings."""
enable_proactive_conversation_starters: bool = True
enable_solvability_analysis: bool | None = None
max_budget_per_task: float | None = None
@classmethod
def from_org(cls, org: Org) -> 'OrgAppSettingsResponse':
"""Create an OrgAppSettingsResponse from an Org entity.
Args:
org: The organization entity
Returns:
OrgAppSettingsResponse with app settings
"""
return cls(
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters
if org.enable_proactive_conversation_starters is not None
else True,
enable_solvability_analysis=org.enable_solvability_analysis,
max_budget_per_task=org.max_budget_per_task,
)
class OrgAppSettingsUpdate(BaseModel):
"""Request model for updating organization app settings."""
enable_proactive_conversation_starters: bool | None = None
enable_solvability_analysis: bool | None = None
max_budget_per_task: float | None = None
@field_validator('max_budget_per_task')
@classmethod
def validate_max_budget_per_task(cls, v: float | None) -> float | None:
if v is not None and v <= 0:
raise ValueError('max_budget_per_task must be greater than 0')
return v

View File

@@ -15,9 +15,13 @@ from server.routes.org_models import (
LiteLLMIntegrationError,
MemberUpdateError,
MeResponse,
OrgAppSettingsResponse,
OrgAppSettingsUpdate,
OrgAuthorizationError,
OrgCreate,
OrgDatabaseError,
OrgLLMSettingsResponse,
OrgLLMSettingsUpdate,
OrgMemberNotFoundError,
OrgMemberPage,
OrgMemberResponse,
@@ -30,6 +34,14 @@ from server.routes.org_models import (
OrphanedUserError,
RoleNotFoundError,
)
from server.services.org_app_settings_service import (
OrgAppSettingsService,
OrgAppSettingsServiceInjector,
)
from server.services.org_llm_settings_service import (
OrgLLMSettingsService,
OrgLLMSettingsServiceInjector,
)
from server.services.org_member_service import OrgMemberService
from storage.org_service import OrgService
from storage.user_store import UserStore
@@ -40,6 +52,13 @@ from openhands.server.user_auth import get_user_id
# Initialize API router
org_router = APIRouter(prefix='/api/organizations', tags=['Orgs'])
# Create injector instance and dependency for LLM settings
_org_llm_settings_injector = OrgLLMSettingsServiceInjector()
org_llm_settings_service_dependency = Depends(_org_llm_settings_injector.depends)
# Create injector instance and dependency at module level
_org_app_settings_injector = OrgAppSettingsServiceInjector()
org_app_settings_service_dependency = Depends(_org_app_settings_injector.depends)
@org_router.get('', response_model=OrgPage)
async def list_user_orgs(
@@ -201,6 +220,195 @@ async def create_org(
)
@org_router.get(
'/llm',
response_model=OrgLLMSettingsResponse,
dependencies=[Depends(require_permission(Permission.VIEW_LLM_SETTINGS))],
)
async def get_org_llm_settings(
service: OrgLLMSettingsService = org_llm_settings_service_dependency,
) -> OrgLLMSettingsResponse:
"""Get LLM settings for the user's current organization.
This endpoint retrieves the LLM configuration settings for the
authenticated user's current organization. All organization members
can view these settings.
Args:
service: OrgLLMSettingsService (injected by dependency)
Returns:
OrgLLMSettingsResponse: The organization's LLM settings
Raises:
HTTPException: 401 if not authenticated
HTTPException: 403 if not a member of any organization
HTTPException: 404 if current organization not found
HTTPException: 500 if retrieval fails
"""
try:
return await service.get_org_llm_settings()
except OrgNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except Exception as e:
logger.exception(
'Error getting organization LLM settings',
extra={'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to retrieve LLM settings',
)
@org_router.post(
'/llm',
response_model=OrgLLMSettingsResponse,
dependencies=[Depends(require_permission(Permission.EDIT_LLM_SETTINGS))],
)
async def update_org_llm_settings(
settings: OrgLLMSettingsUpdate,
service: OrgLLMSettingsService = org_llm_settings_service_dependency,
) -> OrgLLMSettingsResponse:
"""Update LLM settings for the user's current organization.
This endpoint updates the LLM configuration settings for the
authenticated user's current organization. Only admins and owners
can update these settings.
Args:
settings: The LLM settings to update (only non-None fields are updated)
service: OrgLLMSettingsService (injected by dependency)
Returns:
OrgLLMSettingsResponse: The updated organization's LLM settings
Raises:
HTTPException: 401 if not authenticated
HTTPException: 403 if user lacks EDIT_LLM_SETTINGS permission
HTTPException: 404 if current organization not found
HTTPException: 500 if update fails
"""
try:
return await service.update_org_llm_settings(settings)
except OrgNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except OrgDatabaseError as e:
logger.error(
'Database error updating LLM settings',
extra={'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to update LLM settings',
)
except Exception as e:
logger.exception(
'Error updating organization LLM settings',
extra={'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to update LLM settings',
)
@org_router.get(
'/app',
response_model=OrgAppSettingsResponse,
dependencies=[Depends(require_permission(Permission.MANAGE_APPLICATION_SETTINGS))],
)
async def get_org_app_settings(
service: OrgAppSettingsService = org_app_settings_service_dependency,
) -> OrgAppSettingsResponse:
"""Get organization app settings for the user's current organization.
This endpoint retrieves application settings for the authenticated user's
current organization. Access requires the MANAGE_APPLICATION_SETTINGS permission,
which is granted to all organization members (member, admin, and owner roles).
Args:
service: OrgAppSettingsService (injected by dependency)
Returns:
OrgAppSettingsResponse: The organization app settings
Raises:
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user lacks MANAGE_APPLICATION_SETTINGS permission
HTTPException: 404 if current organization not found
"""
try:
return await service.get_org_app_settings()
except OrgNotFoundError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail='Current organization not found',
)
except Exception as e:
logger.exception(
'Unexpected error retrieving organization app settings',
extra={'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@org_router.post(
'/app',
response_model=OrgAppSettingsResponse,
dependencies=[Depends(require_permission(Permission.MANAGE_APPLICATION_SETTINGS))],
)
async def update_org_app_settings(
update_data: OrgAppSettingsUpdate,
service: OrgAppSettingsService = org_app_settings_service_dependency,
) -> OrgAppSettingsResponse:
"""Update organization app settings for the user's current organization.
This endpoint updates application settings for the authenticated user's
current organization. Access requires the MANAGE_APPLICATION_SETTINGS permission,
which is granted to all organization members (member, admin, and owner roles).
Args:
update_data: App settings update data
service: OrgAppSettingsService (injected by dependency)
Returns:
OrgAppSettingsResponse: The updated organization app settings
Raises:
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user lacks MANAGE_APPLICATION_SETTINGS permission
HTTPException: 404 if current organization not found
HTTPException: 422 if validation errors occur (handled by FastAPI)
HTTPException: 500 if update fails
"""
try:
return await service.update_org_app_settings(update_data)
except OrgNotFoundError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail='Current organization not found',
)
except Exception as e:
logger.exception(
'Unexpected error updating organization app settings',
extra={'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
async def get_org(
org_id: UUID,

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, HTTPException, status
from sqlalchemy.sql import text
from storage.database import session_maker
from storage.database import a_session_maker
from storage.redis import create_redis_client
from openhands.core.logger import openhands_logger as logger
@@ -9,11 +9,11 @@ readiness_router = APIRouter()
@readiness_router.get('/ready')
def is_ready():
async def is_ready():
# Check database connection
try:
with session_maker() as session:
session.execute(text('SELECT 1'))
async with a_session_maker() as session:
await session.execute(text('SELECT 1'))
except Exception as e:
logger.error(f'Database check failed: {str(e)}')
raise HTTPException(

View File

@@ -388,5 +388,4 @@ async def _check_idp(
access_token.get_secret_value(), ProviderType(idp)
):
return default_value
return None

View File

@@ -1,184 +0,0 @@
"""API routes for managing verified LLM models (admin only)."""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, field_validator
from server.email_validation import get_admin_user_id
from storage.verified_model_store import VerifiedModelStore
from openhands.core.logger import openhands_logger as logger
api_router = APIRouter(prefix='/api/admin/verified-models', tags=['Verified Models'])
class VerifiedModelCreate(BaseModel):
model_name: str
provider: str
is_enabled: bool = True
@field_validator('model_name')
@classmethod
def validate_model_name(cls, v: str) -> str:
v = v.strip()
if not v or len(v) > 255:
raise ValueError('model_name must be 1-255 characters')
return v
@field_validator('provider')
@classmethod
def validate_provider(cls, v: str) -> str:
v = v.strip()
if not v or len(v) > 100:
raise ValueError('provider must be 1-100 characters')
return v
class VerifiedModelUpdate(BaseModel):
is_enabled: bool | None = None
class VerifiedModelResponse(BaseModel):
id: int
model_name: str
provider: str
is_enabled: bool
class VerifiedModelPage(BaseModel):
"""Paginated response model for verified model list."""
items: list[VerifiedModelResponse]
next_page_id: str | None = None
def _to_response(model) -> VerifiedModelResponse:
return VerifiedModelResponse(
id=model.id,
model_name=model.model_name,
provider=model.provider,
is_enabled=model.is_enabled,
)
@api_router.get('', response_model=VerifiedModelPage)
async def list_verified_models(
provider: str | None = None,
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int, Query(title='The max number of results in the page', gt=0, le=100)
] = 100,
user_id: str = Depends(get_admin_user_id),
):
"""List all verified models, optionally filtered by provider."""
try:
if provider:
all_models = VerifiedModelStore.get_models_by_provider(provider)
else:
all_models = VerifiedModelStore.get_all_models()
try:
offset = int(page_id) if page_id else 0
except ValueError:
offset = 0
page = all_models[offset : offset + limit + 1]
has_more = len(page) > limit
if has_more:
page = page[:limit]
return VerifiedModelPage(
items=[_to_response(m) for m in page],
next_page_id=str(offset + limit) if has_more else None,
)
except Exception:
logger.exception('Error listing verified models')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to list verified models',
)
@api_router.post('', response_model=VerifiedModelResponse, status_code=201)
async def create_verified_model(
data: VerifiedModelCreate,
user_id: str = Depends(get_admin_user_id),
):
"""Create a new verified model."""
try:
model = VerifiedModelStore.create_model(
model_name=data.model_name,
provider=data.provider,
is_enabled=data.is_enabled,
)
return _to_response(model)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=str(e),
)
except Exception:
logger.exception('Error creating verified model')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to create verified model',
)
@api_router.put('/{provider}/{model_name:path}', response_model=VerifiedModelResponse)
async def update_verified_model(
provider: str,
model_name: str,
data: VerifiedModelUpdate,
user_id: str = Depends(get_admin_user_id),
):
"""Update a verified model by provider and model name."""
try:
model = VerifiedModelStore.update_model(
model_name=model_name,
provider=provider,
is_enabled=data.is_enabled,
)
if not model:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f'Model {provider}/{model_name} not found',
)
return _to_response(model)
except HTTPException:
raise
except Exception:
logger.exception(f'Error updating verified model: {provider}/{model_name}')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to update verified model',
)
@api_router.delete('/{provider}/{model_name:path}')
async def delete_verified_model(
provider: str,
model_name: str,
user_id: str = Depends(get_admin_user_id),
):
"""Delete a verified model by provider and model name."""
try:
success = VerifiedModelStore.delete_model(
model_name=model_name, provider=provider
)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f'Model {provider}/{model_name} not found',
)
return {'message': f'Model {provider}/{model_name} deleted'}
except HTTPException:
raise
except Exception:
logger.exception(f'Error deleting verified model: {provider}/{model_name}')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to delete verified model',
)

View File

@@ -0,0 +1,130 @@
"""Service class for managing organization app settings.
Separates business logic from route handlers.
Uses dependency injection for db_session and user_context.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import AsyncGenerator
from fastapi import Request
from server.routes.org_models import (
OrgAppSettingsResponse,
OrgAppSettingsUpdate,
OrgNotFoundError,
)
from storage.org_app_settings_store import OrgAppSettingsStore
from openhands.app_server.services.injector import Injector, InjectorState
from openhands.app_server.user.user_context import UserContext
from openhands.core.logger import openhands_logger as logger
@dataclass
class OrgAppSettingsService:
"""Service for organization app settings with injected dependencies."""
store: OrgAppSettingsStore
user_context: UserContext
async def get_org_app_settings(self) -> OrgAppSettingsResponse:
"""Get organization app settings.
User ID is obtained from the injected user_context.
Returns:
OrgAppSettingsResponse: The organization's app settings
Raises:
OrgNotFoundError: If current organization is not found
"""
user_id = await self.user_context.get_user_id()
logger.info(
'Getting organization app settings',
extra={'user_id': user_id},
)
org = await self.store.get_current_org_by_user_id(user_id)
if not org:
raise OrgNotFoundError('current')
return OrgAppSettingsResponse.from_org(org)
async def update_org_app_settings(
self,
update_data: OrgAppSettingsUpdate,
) -> OrgAppSettingsResponse:
"""Update organization app settings.
Only updates fields that are explicitly provided in update_data.
User ID is obtained from the injected user_context.
Session auto-commits at request end via DbSessionInjector.
Args:
update_data: The update data from the request
Returns:
OrgAppSettingsResponse: The updated organization's app settings
Raises:
OrgNotFoundError: If current organization is not found
"""
user_id = await self.user_context.get_user_id()
logger.info(
'Updating organization app settings',
extra={'user_id': user_id},
)
# Get current org first
org = await self.store.get_current_org_by_user_id(user_id)
if not org:
raise OrgNotFoundError('current')
# Check if any fields are provided
update_dict = update_data.model_dump(exclude_unset=True)
if not update_dict:
# No fields to update, just return current settings
logger.info(
'No fields to update in app settings',
extra={'user_id': user_id, 'org_id': str(org.id)},
)
return OrgAppSettingsResponse.from_org(org)
updated_org = await self.store.update_org_app_settings(
org_id=org.id,
update_data=update_data,
)
if not updated_org:
raise OrgNotFoundError('current')
logger.info(
'Organization app settings updated successfully',
extra={'user_id': user_id, 'updated_fields': list(update_dict.keys())},
)
return OrgAppSettingsResponse.from_org(updated_org)
class OrgAppSettingsServiceInjector(Injector[OrgAppSettingsService]):
"""Injector that composes store and user_context for OrgAppSettingsService."""
async def inject(
self, state: InjectorState, request: Request | None = None
) -> AsyncGenerator[OrgAppSettingsService, None]:
# Local imports to avoid circular dependencies
from openhands.app_server.config import get_db_session, get_user_context
async with (
get_user_context(state, request) as user_context,
get_db_session(state, request) as db_session,
):
store = OrgAppSettingsStore(db_session=db_session)
yield OrgAppSettingsService(store=store, user_context=user_context)

View File

@@ -0,0 +1,130 @@
"""Service class for managing organization LLM settings.
Separates business logic from route handlers.
Uses dependency injection for db_session and user_context.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import AsyncGenerator
from fastapi import Request
from server.routes.org_models import (
OrgLLMSettingsResponse,
OrgLLMSettingsUpdate,
OrgNotFoundError,
)
from storage.org_llm_settings_store import OrgLLMSettingsStore
from openhands.app_server.services.injector import Injector, InjectorState
from openhands.app_server.user.user_context import UserContext
from openhands.core.logger import openhands_logger as logger
@dataclass
class OrgLLMSettingsService:
"""Service for org LLM settings with injected dependencies."""
store: OrgLLMSettingsStore
user_context: UserContext
async def get_org_llm_settings(self) -> OrgLLMSettingsResponse:
"""Get LLM settings for user's current organization.
User ID is obtained from the injected user_context.
Returns:
OrgLLMSettingsResponse: The organization's LLM settings
Raises:
ValueError: If user is not authenticated
OrgNotFoundError: If current organization not found
"""
user_id = await self.user_context.get_user_id()
if not user_id:
raise ValueError('User is not authenticated')
logger.info(
'Getting organization LLM settings',
extra={'user_id': user_id},
)
org = await self.store.get_current_org_by_user_id(user_id)
if not org:
raise OrgNotFoundError('No current organization')
return OrgLLMSettingsResponse.from_org(org)
async def update_org_llm_settings(
self,
update_data: OrgLLMSettingsUpdate,
) -> OrgLLMSettingsResponse:
"""Update LLM settings for user's current organization.
Only updates fields that are explicitly provided in update_data.
User ID is obtained from the injected user_context.
Session auto-commits at request end via DbSessionInjector.
Args:
update_data: The update data from the request
Returns:
OrgLLMSettingsResponse: The updated organization's LLM settings
Raises:
ValueError: If user is not authenticated
OrgNotFoundError: If current organization not found
"""
user_id = await self.user_context.get_user_id()
if not user_id:
raise ValueError('User is not authenticated')
logger.info(
'Updating organization LLM settings',
extra={'user_id': user_id},
)
# Check if any fields are provided
if not update_data.has_updates():
# No fields to update, just return current settings
return await self.get_org_llm_settings()
# Get user's current org first
org = await self.store.get_current_org_by_user_id(user_id)
if not org:
raise OrgNotFoundError('No current organization')
# Update the org LLM settings
updated_org = await self.store.update_org_llm_settings(
org_id=org.id,
update_data=update_data,
)
if not updated_org:
raise OrgNotFoundError(str(org.id))
logger.info(
'Organization LLM settings updated successfully',
extra={'user_id': user_id, 'org_id': str(org.id)},
)
return OrgLLMSettingsResponse.from_org(updated_org)
class OrgLLMSettingsServiceInjector(Injector[OrgLLMSettingsService]):
"""Injector that composes store and user_context for OrgLLMSettingsService."""
async def inject(
self, state: InjectorState, request: Request | None = None
) -> AsyncGenerator[OrgLLMSettingsService, None]:
# Local imports to avoid circular dependencies
from openhands.app_server.config import get_db_session, get_user_context
async with (
get_user_context(state, request) as user_context,
get_db_session(state, request) as db_session,
):
store = OrgLLMSettingsStore(db_session=db_session)
yield OrgLLMSettingsService(store=store, user_context=user_context)

View File

@@ -4,13 +4,14 @@ import pickle
from datetime import datetime
from server.logger import logger
from sqlalchemy import and_, select
from storage.conversation_callback import (
CallbackStatus,
ConversationCallback,
ConversationCallbackProcessor,
)
from storage.conversation_work import ConversationWork
from storage.database import session_maker
from storage.database import a_session_maker, session_maker
from storage.stored_conversation_metadata import StoredConversationMetadata
from openhands.core.config import load_openhands_config
@@ -79,15 +80,16 @@ async def invoke_conversation_callbacks(
conversation_id: The conversation ID to process callbacks for
observation: The AgentStateChangedObservation that triggered the callback
"""
with session_maker() as session:
callbacks = (
session.query(ConversationCallback)
.filter(
ConversationCallback.conversation_id == conversation_id,
ConversationCallback.status == CallbackStatus.ACTIVE,
async with a_session_maker() as session:
result = await session.execute(
select(ConversationCallback).filter(
and_(
ConversationCallback.conversation_id == conversation_id,
ConversationCallback.status == CallbackStatus.ACTIVE,
)
)
.all()
)
callbacks = result.scalars().all()
for callback in callbacks:
try:
@@ -115,7 +117,7 @@ async def invoke_conversation_callbacks(
callback.status = CallbackStatus.ERROR
callback.updated_at = datetime.now()
session.commit()
await session.commit()
def update_conversation_metadata(conversation_id: str, content: dict):

View File

@@ -0,0 +1,33 @@
from datetime import datetime
from typing import Annotated
from pydantic import BaseModel, StringConstraints
class VerifiedModelCreate(BaseModel):
model_name: Annotated[
str,
StringConstraints(max_length=255),
]
provider: Annotated[
str,
StringConstraints(max_length=100),
]
is_enabled: bool = True
class VerifiedModel(VerifiedModelCreate):
id: int
created_at: datetime
updated_at: datetime
class VerifiedModelUpdate(BaseModel):
is_enabled: bool | None = None
class VerifiedModelPage(BaseModel):
"""Paginated response model for verified model list."""
items: list[VerifiedModel]
next_page_id: str | None = None

View File

@@ -0,0 +1,143 @@
"""API routes for managing verified LLM models (admin only)."""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from server.email_validation import get_admin_user_id
from server.verified_models.verified_model_models import (
VerifiedModel,
VerifiedModelCreate,
VerifiedModelPage,
VerifiedModelUpdate,
)
from server.verified_models.verified_model_service import (
VerifiedModelService,
verified_model_store_dependency,
)
from openhands.app_server.config import get_db_session
from openhands.server.routes import public
from openhands.utils.llm import get_supported_llm_models
api_router = APIRouter(prefix='/api/admin/verified-models', tags=['Verified Models'])
@api_router.get('')
async def search_verified_models(
provider: str | None = None,
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int, Query(title='The max number of results in the page', gt=0, le=100)
] = 100,
user_id: str = Depends(get_admin_user_id),
verified_model_service: VerifiedModelService = Depends(
verified_model_store_dependency
),
) -> VerifiedModelPage:
"""List all verified models, optionally filtered by provider."""
# Use SQL-level filtering and pagination
result = await verified_model_service.search_verified_models(
provider=provider,
enabled_only=False, # Admin sees all models including disabled
page_id=page_id,
limit=limit,
)
return result
@api_router.post('', status_code=201)
async def create_verified_model(
data: VerifiedModelCreate,
user_id: str = Depends(get_admin_user_id),
verified_model_service: VerifiedModelService = Depends(
verified_model_store_dependency
),
) -> VerifiedModel:
"""Create a new verified model."""
try:
model = await verified_model_service.create_verified_model(
model_name=data.model_name,
provider=data.provider,
is_enabled=data.is_enabled,
)
return model
except ValueError as ex:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(ex),
)
@api_router.put('/{provider}/{model_name:path}')
async def update_verified_model(
provider: str,
model_name: str,
data: VerifiedModelUpdate,
user_id: str = Depends(get_admin_user_id),
verified_model_service: VerifiedModelService = Depends(
verified_model_store_dependency
),
) -> VerifiedModel:
"""Update a verified model by provider and model name."""
model = await verified_model_service.update_verified_model(
model_name=model_name,
provider=provider,
is_enabled=data.is_enabled,
)
if not model:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f'Model {provider}/{model_name} not found',
)
return model
@api_router.delete('/{provider}/{model_name:path}')
async def delete_verified_model(
provider: str,
model_name: str,
user_id: str = Depends(get_admin_user_id),
verified_model_service: VerifiedModelService = Depends(
verified_model_store_dependency
),
) -> bool:
"""Delete a verified model by provider and model name."""
try:
await verified_model_service.delete_verified_model(
model_name=model_name, provider=provider
)
return True
except ValueError as ex:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(ex),
)
async def get_saas_llm_models_dependency(request: Request) -> list[str]:
"""SaaS implementation for the LLM models endpoint."""
async with get_db_session(request.state, request) as db_session:
# Prevent circular import
from openhands.server.shared import config
verified_model_service = VerifiedModelService(db_session)
page = await verified_model_service.search_verified_models(enabled_only=True)
if page.next_page_id:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Too many models defined in database',
)
verified_models = [f'{m.provider}/{m.model_name}' for m in page.items]
return get_supported_llm_models(config, verified_models)
# Override the default implementation with SaaS implementation
# This must be called after the app is created in saas_server.py
def override_llm_models_dependency(app):
"""Override the default LLM models implementation with SaaS version."""
app.dependency_overrides[public.get_llm_models_dependency] = (
get_saas_llm_models_dependency
)

View File

@@ -0,0 +1,242 @@
"""Store for managing verified LLM models in the database."""
from dataclasses import dataclass
from server.verified_models.verified_model_models import (
VerifiedModel,
VerifiedModelPage,
)
from sqlalchemy import (
Boolean,
Column,
DateTime,
Identity,
Integer,
String,
UniqueConstraint,
and_,
func,
select,
text,
)
from sqlalchemy.ext.asyncio import AsyncSession
from storage.base import Base
from openhands.app_server.config import depends_db_session
from openhands.core.logger import openhands_logger as logger
class StoredVerifiedModel(Base): # type: ignore
"""A verified LLM model available in the model selector.
The composite unique constraint on (model_name, provider) allows the same
model name to exist under different providers (e.g. 'claude-sonnet' under
both 'openhands' and 'anthropic').
"""
__tablename__ = 'verified_models'
__table_args__ = (
UniqueConstraint('model_name', 'provider', name='uq_verified_model_provider'),
)
id = Column(Integer, Identity(), primary_key=True)
model_name = Column(String(255), nullable=False)
provider = Column(String(100), nullable=False, index=True)
is_enabled = Column(
Boolean, nullable=False, default=True, server_default=text('true')
)
created_at = Column(DateTime, nullable=False, server_default=func.now())
updated_at = Column(
DateTime, nullable=False, server_default=func.now(), onupdate=func.now()
)
def verified_model(result: StoredVerifiedModel) -> VerifiedModel:
return VerifiedModel(
id=result.id,
model_name=result.model_name,
provider=result.provider,
is_enabled=result.is_enabled,
created_at=result.created_at,
updated_at=result.updated_at,
)
@dataclass
class VerifiedModelService:
"""Store for CRUD operations on verified models.
Follows the async pattern with db_session as an attribute.
"""
db_session: AsyncSession
async def search_verified_models(
self,
provider: str | None = None,
enabled_only: bool = True,
page_id: str | None = None,
limit: int = 100,
) -> VerifiedModelPage:
"""Search for verified models with optional filtering and pagination.
Args:
provider: Optional provider name to filter by (e.g., 'openhands', 'anthropic')
enabled_only: If True, only return enabled models (default: True)
page_id: Page id for pagination
limit: Maximum number of records to return
Returns:
SearchModelsResult containing items list and has_more flag
"""
query = select(StoredVerifiedModel)
# Build filters
filters = []
if provider:
filters.append(StoredVerifiedModel.provider == provider)
if enabled_only:
filters.append(StoredVerifiedModel.is_enabled.is_(True))
if filters:
query = query.where(and_(*filters))
# Order by provider, then model_name
query = query.order_by(
StoredVerifiedModel.provider, StoredVerifiedModel.model_name
)
# Fetch limit + 1 to check if there are more results
offset = int(page_id or '0')
query = query.offset(offset).limit(limit + 1)
result = await self.db_session.execute(query)
results = list(result.scalars().all())
has_more = len(results) > limit
next_page_id = None
# Return only the requested number of results
if has_more:
next_page_id = str(offset + limit)
results.pop()
items = [verified_model(result) for result in results]
return VerifiedModelPage(items=items, next_page_id=next_page_id)
async def get_model(self, model_name: str, provider: str) -> VerifiedModel | None:
"""Get a model by its composite key (model_name, provider).
Args:
model_name: The model identifier
provider: The provider name
"""
query = select(StoredVerifiedModel).where(
and_(
StoredVerifiedModel.model_name == model_name,
StoredVerifiedModel.provider == provider,
)
)
result = await self.db_session.execute(query)
return result.scalars().first()
async def create_verified_model(
self,
model_name: str,
provider: str,
is_enabled: bool = True,
) -> VerifiedModel:
"""Create a new verified model.
Args:
model_name: The model identifier
provider: The provider name
is_enabled: Whether the model is enabled (default True)
Raises:
ValueError: If a model with the same (model_name, provider) already exists
"""
existing_query = select(StoredVerifiedModel).where(
and_(
StoredVerifiedModel.model_name == model_name,
StoredVerifiedModel.provider == provider,
)
)
result = await self.db_session.execute(existing_query)
existing = result.scalars().first()
if existing:
raise ValueError(f'Model {provider}/{model_name} already exists')
model = StoredVerifiedModel(
model_name=model_name,
provider=provider,
is_enabled=is_enabled,
)
self.db_session.add(model)
await self.db_session.commit()
await self.db_session.refresh(model)
logger.info(f'Created verified model: {provider}/{model_name}')
return verified_model(model)
async def update_verified_model(
self,
model_name: str,
provider: str,
is_enabled: bool | None = None,
) -> VerifiedModel | None:
"""Update an existing verified model.
Args:
model_name: The model name to update
provider: The provider name
is_enabled: New enabled state (optional)
Returns:
The updated model if found, None otherwise
"""
query = select(StoredVerifiedModel).where(
and_(
StoredVerifiedModel.model_name == model_name,
StoredVerifiedModel.provider == provider,
)
)
result = await self.db_session.execute(query)
model = result.scalars().first()
if not model:
return None
if is_enabled is not None:
model.is_enabled = is_enabled
await self.db_session.commit()
await self.db_session.refresh(model)
logger.info(f'Updated verified model: {provider}/{model_name}')
return verified_model(model)
async def delete_verified_model(self, model_name: str, provider: str):
"""Delete a verified model.
Args:
model_name: The model name to delete
provider: The provider name
Returns:
True if deleted, False if not found
"""
query = select(StoredVerifiedModel).where(
and_(
StoredVerifiedModel.model_name == model_name,
StoredVerifiedModel.provider == provider,
)
)
result = await self.db_session.execute(query)
model = result.scalars().first()
if not model:
raise ValueError('Unknown model')
await self.db_session.delete(model)
await self.db_session.commit()
logger.info(f'Deleted verified model: {provider}/{model_name}')
def verified_model_store_dependency(db_session: AsyncSession = depends_db_session()):
return VerifiedModelService(db_session)

View File

@@ -5,20 +5,16 @@ import string
from dataclasses import dataclass
from datetime import UTC, datetime
from sqlalchemy import update
from sqlalchemy.orm import sessionmaker
from sqlalchemy import select, update
from storage.api_key import ApiKey
from storage.database import session_maker
from storage.database import a_session_maker
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
from openhands.utils.async_utils import call_sync_from_async
@dataclass
class ApiKeyStore:
session_maker: sessionmaker
API_KEY_PREFIX = 'sk-oh-'
def generate_api_key(self, length: int = 32) -> str:
@@ -43,22 +39,8 @@ class ApiKeyStore:
api_key = self.generate_api_key()
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id
await call_sync_from_async(
self._store_api_key, user_id, org_id, api_key, name, expires_at
)
return api_key
def _store_api_key(
self,
user_id: str,
org_id: str,
api_key: str,
name: str | None,
expires_at: datetime | None = None,
) -> None:
"""Store an existing API key in the database."""
with self.session_maker() as session:
async with a_session_maker() as session:
key_record = ApiKey(
key=api_key,
user_id=user_id,
@@ -67,14 +49,17 @@ class ApiKeyStore:
expires_at=expires_at,
)
session.add(key_record)
session.commit()
await session.commit()
def validate_api_key(self, api_key: str) -> str | None:
return api_key
async def validate_api_key(self, api_key: str) -> str | None:
"""Validate an API key and return the associated user_id if valid."""
now = datetime.now(UTC)
with self.session_maker() as session:
key_record = session.query(ApiKey).filter(ApiKey.key == api_key).first()
async with a_session_maker() as session:
result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key))
key_record = result.scalars().first()
if not key_record:
return None
@@ -91,38 +76,40 @@ class ApiKeyStore:
return None
# Update last_used_at timestamp
session.execute(
await session.execute(
update(ApiKey)
.where(ApiKey.id == key_record.id)
.values(last_used_at=now)
)
session.commit()
await session.commit()
return key_record.user_id
def delete_api_key(self, api_key: str) -> bool:
async def delete_api_key(self, api_key: str) -> bool:
"""Delete an API key by the key value."""
with self.session_maker() as session:
key_record = session.query(ApiKey).filter(ApiKey.key == api_key).first()
async with a_session_maker() as session:
result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key))
key_record = result.scalars().first()
if not key_record:
return False
session.delete(key_record)
session.commit()
await session.delete(key_record)
await session.commit()
return True
def delete_api_key_by_id(self, key_id: int) -> bool:
async def delete_api_key_by_id(self, key_id: int) -> bool:
"""Delete an API key by its ID."""
with self.session_maker() as session:
key_record = session.query(ApiKey).filter(ApiKey.id == key_id).first()
async with a_session_maker() as session:
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
key_record = result.scalars().first()
if not key_record:
return False
session.delete(key_record)
session.commit()
await session.delete(key_record)
await session.commit()
return True
@@ -130,64 +117,55 @@ class ApiKeyStore:
"""List all API keys for a user."""
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id
return await call_sync_from_async(self._list_api_keys_from_db, user_id, org_id)
def _list_api_keys_from_db(self, user_id: str, org_id: str) -> list[ApiKey]:
with self.session_maker() as session:
keys: list[ApiKey] = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id)
.filter(ApiKey.org_id == org_id)
.all()
async with a_session_maker() as session:
result = await session.execute(
select(ApiKey).filter(
ApiKey.user_id == user_id, ApiKey.org_id == org_id
)
)
keys = result.scalars().all()
return [key for key in keys if key.name != 'MCP_API_KEY']
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id
return await call_sync_from_async(
self._retrieve_mcp_api_key_from_db, user_id, org_id
)
def _retrieve_mcp_api_key_from_db(self, user_id: str, org_id: str) -> str | None:
with self.session_maker() as session:
keys: list[ApiKey] = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id)
.filter(ApiKey.org_id == org_id)
.all()
async with a_session_maker() as session:
result = await session.execute(
select(ApiKey).filter(
ApiKey.user_id == user_id, ApiKey.org_id == org_id
)
)
keys = result.scalars().all()
for key in keys:
if key.name == 'MCP_API_KEY':
return key.key
return None
def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None:
async def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None:
"""Retrieve an API key by name for a specific user."""
with self.session_maker() as session:
key_record = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
.first()
async with a_session_maker() as session:
result = await session.execute(
select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name)
)
key_record = result.scalars().first()
return key_record.key if key_record else None
def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
async def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
"""Delete an API key by name for a specific user."""
with self.session_maker() as session:
key_record = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
.first()
async with a_session_maker() as session:
result = await session.execute(
select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name)
)
key_record = result.scalars().first()
if not key_record:
return False
session.delete(key_record)
session.commit()
await session.delete(key_record)
await session.commit()
return True
@@ -195,4 +173,4 @@ class ApiKeyStore:
def get_instance(cls) -> ApiKeyStore:
"""Get an instance of the ApiKeyStore."""
logger.debug('api_key_store.get_instance')
return ApiKeyStore(session_maker)
return ApiKeyStore()

View File

@@ -7,7 +7,6 @@ from typing import Awaitable, Callable, Dict
from server.auth.auth_error import TokenRefreshError
from sqlalchemy import select, text, update
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from storage.auth_tokens import AuthTokens
from storage.database import a_session_maker
@@ -27,7 +26,6 @@ LOCK_TIMEOUT_SECONDS = 5
class AuthTokenStore:
keycloak_user_id: str
idp: ProviderType
a_session_maker: sessionmaker
@property
def identity_provider_value(self) -> str:
@@ -73,7 +71,7 @@ class AuthTokenStore:
access_token_expires_at: Expiration time for access token (seconds since epoch)
refresh_token_expires_at: Expiration time for refresh token (seconds since epoch)
"""
async with self.a_session_maker() as session:
async with a_session_maker() as session:
async with session.begin(): # Explicitly start a transaction
result = await session.execute(
select(AuthTokens).where(
@@ -138,7 +136,7 @@ class AuthTokenStore:
a 401 response to prompt the user to re-authenticate.
"""
# FAST PATH: Check without lock first to avoid unnecessary lock contention
async with self.a_session_maker() as session:
async with a_session_maker() as session:
result = await session.execute(
select(AuthTokens).filter(
AuthTokens.keycloak_user_id == self.keycloak_user_id,
@@ -167,7 +165,7 @@ class AuthTokenStore:
# SLOW PATH: Token needs refresh, acquire lock
try:
async with self.a_session_maker() as session:
async with a_session_maker() as session:
async with session.begin():
# Set a lock timeout to prevent indefinite blocking
# This ensures we don't hold connections forever if something goes wrong
@@ -300,6 +298,4 @@ class AuthTokenStore:
logger.debug(f'auth_token_store.get_instance::{keycloak_user_id}')
if keycloak_user_id:
keycloak_user_id = str(keycloak_user_id)
return AuthTokenStore(
keycloak_user_id=keycloak_user_id, idp=idp, a_session_maker=a_session_maker
)
return AuthTokenStore(keycloak_user_id=keycloak_user_id, idp=idp)

View File

@@ -1,14 +1,12 @@
from dataclasses import dataclass
from sqlalchemy import text
from sqlalchemy.orm import sessionmaker
from storage.database import a_session_maker
@dataclass
class BlockedEmailDomainStore:
session_maker: sessionmaker
def is_domain_blocked(self, domain: str) -> bool:
async def is_domain_blocked(self, domain: str) -> bool:
"""Check if a domain is blocked by querying the database directly.
This method uses SQL to efficiently check if the domain matches any blocked pattern:
@@ -21,9 +19,9 @@ class BlockedEmailDomainStore:
Returns:
True if the domain is blocked, False otherwise
"""
with self.session_maker() as session:
async with a_session_maker() as session:
# SQL query that handles both TLD patterns and full domain patterns
# TLD patterns (starting with '.'): check if domain ends with the pattern
# TLD patterns (starting with '.'): check if domain ends with it (case-insensitive)
# Full domain patterns: check for exact match or subdomain match
# All comparisons are case-insensitive using LOWER() to ensure consistent matching
query = text("""
@@ -41,5 +39,5 @@ class BlockedEmailDomainStore:
))
)
""")
result = session.execute(query, {'domain': domain}).scalar()
return bool(result)
result = await session.execute(query, {'domain': domain})
return bool(result.scalar())

View File

@@ -47,7 +47,11 @@ class DeviceCode(Base):
def is_expired(self) -> bool:
"""Check if the device code has expired."""
now = datetime.now(timezone.utc)
return now > self.expires_at
# Handle timezone-naive datetime from database by assuming it's UTC
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return now > expires_at
def is_pending(self) -> bool:
"""Check if the device code is still pending authorization."""
@@ -85,8 +89,13 @@ class DeviceCode(Base):
if self.last_poll_time is None:
return False, self.current_interval
# Handle timezone-naive datetime from database by assuming it's UTC
last_poll_time = self.last_poll_time
if last_poll_time.tzinfo is None:
last_poll_time = last_poll_time.replace(tzinfo=timezone.utc)
# Calculate time since last poll
time_since_last_poll = (now - self.last_poll_time).total_seconds()
time_since_last_poll = (now - last_poll_time).total_seconds()
# Check if polling too fast
if time_since_last_poll < self.current_interval:

View File

@@ -1,19 +1,20 @@
"""Device code store for OAuth 2.0 Device Flow."""
from __future__ import annotations
import secrets
import string
from datetime import datetime, timedelta, timezone
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from storage.database import a_session_maker
from storage.device_code import DeviceCode
class DeviceCodeStore:
"""Store for managing OAuth 2.0 device codes."""
def __init__(self, session_maker):
self.session_maker = session_maker
def generate_user_code(self) -> str:
"""Generate a human-readable user code (8 characters, uppercase letters and digits)."""
# Use a mix of uppercase letters and digits, avoiding confusing characters
@@ -25,7 +26,7 @@ class DeviceCodeStore:
alphabet = string.ascii_letters + string.digits
return ''.join(secrets.choice(alphabet) for _ in range(128))
def create_device_code(
async def create_device_code(
self,
expires_in: int = 600, # 10 minutes default
max_attempts: int = 10,
@@ -58,11 +59,10 @@ class DeviceCodeStore:
)
try:
with self.session_maker() as session:
async with a_session_maker() as session:
session.add(device_code_entry)
session.commit()
session.refresh(device_code_entry)
session.expunge(device_code_entry) # Detach from session cleanly
await session.commit()
await session.refresh(device_code_entry)
return device_code_entry
except IntegrityError:
# Constraint violation - codes already exist, retry with new codes
@@ -72,25 +72,23 @@ class DeviceCodeStore:
f'Failed to generate unique device codes after {max_attempts} attempts'
)
def get_by_device_code(self, device_code: str) -> DeviceCode | None:
async def get_by_device_code(self, device_code: str) -> DeviceCode | None:
"""Get device code entry by device code."""
with self.session_maker() as session:
result = (
session.query(DeviceCode).filter_by(device_code=device_code).first()
async with a_session_maker() as session:
result = await session.execute(
select(DeviceCode).filter_by(device_code=device_code)
)
if result:
session.expunge(result) # Detach from session cleanly
return result
return result.scalars().first()
def get_by_user_code(self, user_code: str) -> DeviceCode | None:
async def get_by_user_code(self, user_code: str) -> DeviceCode | None:
"""Get device code entry by user code."""
with self.session_maker() as session:
result = session.query(DeviceCode).filter_by(user_code=user_code).first()
if result:
session.expunge(result) # Detach from session cleanly
return result
async with a_session_maker() as session:
result = await session.execute(
select(DeviceCode).filter_by(user_code=user_code)
)
return result.scalars().first()
def authorize_device_code(self, user_code: str, user_id: str) -> bool:
async def authorize_device_code(self, user_code: str, user_id: str) -> bool:
"""Authorize a device code.
Args:
@@ -100,10 +98,11 @@ class DeviceCodeStore:
Returns:
True if authorization was successful, False otherwise
"""
with self.session_maker() as session:
device_code_entry = (
session.query(DeviceCode).filter_by(user_code=user_code).first()
async with a_session_maker() as session:
result = await session.execute(
select(DeviceCode).filter_by(user_code=user_code)
)
device_code_entry = result.scalars().first()
if not device_code_entry:
return False
@@ -112,11 +111,11 @@ class DeviceCodeStore:
return False
device_code_entry.authorize(user_id)
session.commit()
await session.commit()
return True
def deny_device_code(self, user_code: str) -> bool:
async def deny_device_code(self, user_code: str) -> bool:
"""Deny a device code authorization.
Args:
@@ -125,10 +124,11 @@ class DeviceCodeStore:
Returns:
True if denial was successful, False otherwise
"""
with self.session_maker() as session:
device_code_entry = (
session.query(DeviceCode).filter_by(user_code=user_code).first()
async with a_session_maker() as session:
result = await session.execute(
select(DeviceCode).filter_by(user_code=user_code)
)
device_code_entry = result.scalars().first()
if not device_code_entry:
return False
@@ -137,11 +137,11 @@ class DeviceCodeStore:
return False
device_code_entry.deny()
session.commit()
await session.commit()
return True
def update_poll_time(
async def update_poll_time(
self, device_code: str, increase_interval: bool = False
) -> bool:
"""Update the poll time for a device code and optionally increase interval.
@@ -153,15 +153,16 @@ class DeviceCodeStore:
Returns:
True if update was successful, False otherwise
"""
with self.session_maker() as session:
device_code_entry = (
session.query(DeviceCode).filter_by(device_code=device_code).first()
async with a_session_maker() as session:
result = await session.execute(
select(DeviceCode).filter_by(device_code=device_code)
)
device_code_entry = result.scalars().first()
if not device_code_entry:
return False
device_code_entry.update_poll_time(increase_interval)
session.commit()
await session.commit()
return True

View File

@@ -5,7 +5,6 @@ from dataclasses import dataclass
from integrations.types import GitLabResourceType
from sqlalchemy import and_, asc, select, text, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import sessionmaker
from storage.database import a_session_maker
from storage.gitlab_webhook import GitlabWebhook
@@ -14,8 +13,6 @@ from openhands.core.logger import openhands_logger as logger
@dataclass
class GitlabWebhookStore:
a_session_maker: sessionmaker = a_session_maker
@staticmethod
def determine_resource_type(
webhook: GitlabWebhook,
@@ -44,7 +41,7 @@ class GitlabWebhookStore:
if not project_details:
return
async with self.a_session_maker() as session:
async with a_session_maker() as session:
async with session.begin():
# Convert GitlabWebhook objects to dictionaries for the insert
# Using __dict__ and filtering out SQLAlchemy internal attributes and 'id'
@@ -88,7 +85,7 @@ class GitlabWebhookStore:
"""
resource_type, resource_id = GitlabWebhookStore.determine_resource_type(webhook)
async with self.a_session_maker() as session:
async with a_session_maker() as session:
async with session.begin():
stmt = (
update(GitlabWebhook).where(GitlabWebhook.project_id == resource_id)
@@ -122,7 +119,7 @@ class GitlabWebhookStore:
},
)
async with self.a_session_maker() as session:
async with a_session_maker() as session:
async with session.begin():
# Create query based on the identifier provided
if resource_type == GitLabResourceType.PROJECT:
@@ -185,7 +182,7 @@ class GitlabWebhookStore:
List of GitlabWebhook objects that need processing
"""
async with self.a_session_maker() as session:
async with a_session_maker() as session:
query = (
select(GitlabWebhook)
.where(GitlabWebhook.webhook_exists.is_(False))
@@ -201,7 +198,7 @@ class GitlabWebhookStore:
"""
Get's webhook secret given the webhook uuid and admin keycloak user id
"""
async with self.a_session_maker() as session:
async with a_session_maker() as session:
query = (
select(GitlabWebhook)
.where(
@@ -235,7 +232,7 @@ class GitlabWebhookStore:
Returns:
GitlabWebhook object if found, None otherwise
"""
async with self.a_session_maker() as session:
async with a_session_maker() as session:
if resource_type == GitLabResourceType.PROJECT:
query = select(GitlabWebhook).where(
GitlabWebhook.project_id == resource_id
@@ -263,7 +260,7 @@ class GitlabWebhookStore:
Returns:
Tuple of (project_webhook_map, group_webhook_map)
"""
async with self.a_session_maker() as session:
async with a_session_maker() as session:
project_webhook_map = {}
group_webhook_map = {}
@@ -303,7 +300,7 @@ class GitlabWebhookStore:
Returns:
True if webhook was reset, False if not found
"""
async with self.a_session_maker() as session:
async with a_session_maker() as session:
async with session.begin():
if resource_type == GitLabResourceType.PROJECT:
update_statement = (
@@ -348,4 +345,4 @@ class GitlabWebhookStore:
Returns:
An instance of GitlabWebhookStore
"""
return GitlabWebhookStore(a_session_maker)
return GitlabWebhookStore()

View File

@@ -3,7 +3,8 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from storage.database import session_maker
from sqlalchemy import select
from storage.database import a_session_maker
from storage.jira_dc_conversation import JiraDcConversation
from storage.jira_dc_user import JiraDcUser
from storage.jira_dc_workspace import JiraDcWorkspace
@@ -24,7 +25,7 @@ class JiraDcIntegrationStore:
) -> JiraDcWorkspace:
"""Create a new Jira DC workspace with encrypted sensitive data."""
with session_maker() as session:
async with a_session_maker() as session:
workspace = JiraDcWorkspace(
name=name.lower(),
admin_user_id=admin_user_id,
@@ -34,8 +35,8 @@ class JiraDcIntegrationStore:
status=status,
)
session.add(workspace)
session.commit()
session.refresh(workspace)
await session.commit()
await session.refresh(workspace)
logger.info(f'[Jira DC] Created workspace {workspace.name}')
return workspace
@@ -48,11 +49,12 @@ class JiraDcIntegrationStore:
status: Optional[str] = None,
) -> JiraDcWorkspace:
"""Update an existing Jira DC workspace with encrypted sensitive data."""
with session_maker() as session:
async with a_session_maker() as session:
# Find existing workspace by ID
workspace = (
session.query(JiraDcWorkspace).filter(JiraDcWorkspace.id == id).first()
result = await session.execute(
select(JiraDcWorkspace).where(JiraDcWorkspace.id == id)
)
workspace = result.scalar_one_or_none()
if not workspace:
raise ValueError(f'Workspace with ID "{id}" not found')
@@ -69,8 +71,8 @@ class JiraDcIntegrationStore:
if status is not None:
workspace.status = status
session.commit()
session.refresh(workspace)
await session.commit()
await session.refresh(workspace)
logger.info(f'[Jira DC] Updated workspace {workspace.name}')
return workspace
@@ -91,10 +93,10 @@ class JiraDcIntegrationStore:
status=status,
)
with session_maker() as session:
async with a_session_maker() as session:
session.add(jira_dc_user)
session.commit()
session.refresh(jira_dc_user)
await session.commit()
await session.refresh(jira_dc_user)
logger.info(
f'[Jira DC] Created user {jira_dc_user.id} for workspace {jira_dc_workspace_id}'
@@ -103,94 +105,91 @@ class JiraDcIntegrationStore:
async def get_workspace_by_id(self, workspace_id: int) -> Optional[JiraDcWorkspace]:
"""Retrieve workspace by ID."""
with session_maker() as session:
return (
session.query(JiraDcWorkspace)
.filter(JiraDcWorkspace.id == workspace_id)
.first()
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcWorkspace).where(JiraDcWorkspace.id == workspace_id)
)
return result.scalar_one_or_none()
async def get_workspace_by_name(
self, workspace_name: str
) -> Optional[JiraDcWorkspace]:
"""Retrieve workspace by name."""
with session_maker() as session:
return (
session.query(JiraDcWorkspace)
.filter(JiraDcWorkspace.name == workspace_name.lower())
.first()
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcWorkspace).where(
JiraDcWorkspace.name == workspace_name.lower()
)
)
return result.scalar_one_or_none()
async def get_user_by_active_workspace(
self, keycloak_user_id: str
) -> Optional[JiraDcUser]:
"""Retrieve user by Keycloak user ID."""
with session_maker() as session:
return (
session.query(JiraDcUser)
.filter(
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcUser).where(
JiraDcUser.keycloak_user_id == keycloak_user_id,
JiraDcUser.status == 'active',
)
.first()
)
return result.scalar_one_or_none()
async def get_user_by_keycloak_id_and_workspace(
self, keycloak_user_id: str, jira_dc_workspace_id: int
) -> Optional[JiraDcUser]:
"""Get Jira DC user by Keycloak user ID and workspace ID."""
with session_maker() as session:
return (
session.query(JiraDcUser)
.filter(
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcUser).where(
JiraDcUser.keycloak_user_id == keycloak_user_id,
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
)
.first()
)
return result.scalar_one_or_none()
async def get_active_user(
self, jira_dc_user_id: str, jira_dc_workspace_id: int
) -> Optional[JiraDcUser]:
"""Get Jira DC user by Keycloak user ID and workspace ID."""
with session_maker() as session:
return (
session.query(JiraDcUser)
.filter(
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcUser).where(
JiraDcUser.jira_dc_user_id == jira_dc_user_id,
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
JiraDcUser.status == 'active',
)
.first()
)
return result.scalar_one_or_none()
async def get_active_user_by_keycloak_id_and_workspace(
self, keycloak_user_id: str, jira_dc_workspace_id: int
) -> Optional[JiraDcUser]:
"""Get Jira DC user by Keycloak user ID and workspace ID."""
with session_maker() as session:
return (
session.query(JiraDcUser)
.filter(
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcUser).where(
JiraDcUser.keycloak_user_id == keycloak_user_id,
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
JiraDcUser.status == 'active',
)
.first()
)
return result.scalar_one_or_none()
async def update_user_integration_status(
self, keycloak_user_id: str, status: str
) -> JiraDcUser:
"""Update the status of a Jira DC user mapping."""
with session_maker() as session:
user = (
session.query(JiraDcUser)
.filter(JiraDcUser.keycloak_user_id == keycloak_user_id)
.first()
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcUser).where(
JiraDcUser.keycloak_user_id == keycloak_user_id
)
)
user = result.scalar_one_or_none()
if not user:
raise ValueError(
@@ -198,37 +197,35 @@ class JiraDcIntegrationStore:
)
user.status = status
session.commit()
session.refresh(user)
await session.commit()
await session.refresh(user)
logger.info(f'[Jira DC] Updated user {keycloak_user_id} status to {status}')
return user
async def deactivate_workspace(self, workspace_id: int):
"""Deactivate the workspace and all user links for a given workspace."""
with session_maker() as session:
users = (
session.query(JiraDcUser)
.filter(
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcUser).where(
JiraDcUser.jira_dc_workspace_id == workspace_id,
JiraDcUser.status == 'active',
)
.all()
)
users = result.scalars().all()
for user in users:
user.status = 'inactive'
session.add(user)
workspace = (
session.query(JiraDcWorkspace)
.filter(JiraDcWorkspace.id == workspace_id)
.first()
result = await session.execute(
select(JiraDcWorkspace).where(JiraDcWorkspace.id == workspace_id)
)
workspace = result.scalar_one_or_none()
if workspace:
workspace.status = 'inactive'
session.add(workspace)
session.commit()
await session.commit()
logger.info(
f'[Jira DC] Deactivated all user links for workspace {workspace_id}'
@@ -238,23 +235,22 @@ class JiraDcIntegrationStore:
self, jira_dc_conversation: JiraDcConversation
) -> None:
"""Create a new Jira DC conversation record."""
with session_maker() as session:
async with a_session_maker() as session:
session.add(jira_dc_conversation)
session.commit()
await session.commit()
async def get_user_conversations_by_issue_id(
self, issue_id: str, jira_dc_user_id: int
) -> JiraDcConversation | None:
"""Get a Jira DC conversation by issue ID and jira dc user ID."""
with session_maker() as session:
return (
session.query(JiraDcConversation)
.filter(
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcConversation).where(
JiraDcConversation.issue_id == issue_id,
JiraDcConversation.jira_dc_user_id == jira_dc_user_id,
)
.first()
)
return result.scalar_one_or_none()
@classmethod
def get_instance(cls) -> JiraDcIntegrationStore:

View File

@@ -3,7 +3,8 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from storage.database import session_maker
from sqlalchemy import and_, select
from storage.database import a_session_maker
from storage.jira_conversation import JiraConversation
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
@@ -35,10 +36,10 @@ class JiraIntegrationStore:
status=status,
)
with session_maker() as session:
async with a_session_maker() as session:
session.add(workspace)
session.commit()
session.refresh(workspace)
await session.commit()
await session.refresh(workspace)
logger.info(f'[Jira] Created workspace {workspace.name}')
return workspace
@@ -53,11 +54,12 @@ class JiraIntegrationStore:
status: Optional[str] = None,
) -> JiraWorkspace:
"""Update an existing Jira workspace with encrypted sensitive data."""
with session_maker() as session:
async with a_session_maker() as session:
# Find existing workspace by ID
workspace = (
session.query(JiraWorkspace).filter(JiraWorkspace.id == id).first()
result = await session.execute(
select(JiraWorkspace).filter(JiraWorkspace.id == id)
)
workspace = result.scalars().first()
if not workspace:
raise ValueError(f'Workspace with ID "{id}" not found')
@@ -77,11 +79,11 @@ class JiraIntegrationStore:
if status is not None:
workspace.status = status
session.commit()
session.refresh(workspace)
await session.commit()
await session.refresh(workspace)
logger.info(f'[Jira] Updated workspace {workspace.name}')
return workspace
logger.info(f'[Jira] Updated workspace {workspace.name}')
return workspace
async def create_workspace_link(
self,
@@ -99,10 +101,10 @@ class JiraIntegrationStore:
status=status,
)
with session_maker() as session:
async with a_session_maker() as session:
session.add(jira_user)
session.commit()
session.refresh(jira_user)
await session.commit()
await session.refresh(jira_user)
logger.info(
f'[Jira] Created user {jira_user.id} for workspace {jira_workspace_id}'
@@ -111,75 +113,77 @@ class JiraIntegrationStore:
async def get_workspace_by_id(self, workspace_id: int) -> Optional[JiraWorkspace]:
"""Retrieve workspace by ID."""
with session_maker() as session:
return (
session.query(JiraWorkspace)
.filter(JiraWorkspace.id == workspace_id)
.first()
async with a_session_maker() as session:
result = await session.execute(
select(JiraWorkspace).filter(JiraWorkspace.id == workspace_id)
)
return result.scalars().first()
async def get_workspace_by_name(self, workspace_name: str) -> JiraWorkspace | None:
"""Retrieve workspace by name."""
with session_maker() as session:
return (
session.query(JiraWorkspace)
.filter(JiraWorkspace.name == workspace_name.lower())
.first()
async with a_session_maker() as session:
result = await session.execute(
select(JiraWorkspace).filter(
JiraWorkspace.name == workspace_name.lower()
)
)
return result.scalars().first()
async def get_user_by_active_workspace(
self, keycloak_user_id: str
) -> Optional[JiraUser]:
"""Get Jira user by Keycloak user ID."""
with session_maker() as session:
return (
session.query(JiraUser)
.filter(
JiraUser.keycloak_user_id == keycloak_user_id,
JiraUser.status == 'active',
async with a_session_maker() as session:
result = await session.execute(
select(JiraUser).filter(
and_(
JiraUser.keycloak_user_id == keycloak_user_id,
JiraUser.status == 'active',
)
)
.first()
)
return result.scalars().first()
async def get_user_by_keycloak_id_and_workspace(
self, keycloak_user_id: str, jira_workspace_id: int
) -> Optional[JiraUser]:
"""Get Jira user by Keycloak user ID and workspace ID."""
with session_maker() as session:
return (
session.query(JiraUser)
.filter(
JiraUser.keycloak_user_id == keycloak_user_id,
JiraUser.jira_workspace_id == jira_workspace_id,
async with a_session_maker() as session:
result = await session.execute(
select(JiraUser).filter(
and_(
JiraUser.keycloak_user_id == keycloak_user_id,
JiraUser.jira_workspace_id == jira_workspace_id,
)
)
.first()
)
return result.scalars().first()
async def get_active_user(
self, jira_user_id: str, jira_workspace_id: int
) -> Optional[JiraUser]:
"""Get Jira user by Keycloak user ID and workspace ID."""
with session_maker() as session:
return (
session.query(JiraUser)
.filter(
JiraUser.jira_user_id == jira_user_id,
JiraUser.jira_workspace_id == jira_workspace_id,
JiraUser.status == 'active',
async with a_session_maker() as session:
result = await session.execute(
select(JiraUser).filter(
and_(
JiraUser.jira_user_id == jira_user_id,
JiraUser.jira_workspace_id == jira_workspace_id,
JiraUser.status == 'active',
)
)
.first()
)
return result.scalars().first()
async def update_user_integration_status(
self, keycloak_user_id: str, status: str
) -> JiraUser:
"""Update Jira user integration status."""
with session_maker() as session:
jira_user = (
session.query(JiraUser)
.filter(JiraUser.keycloak_user_id == keycloak_user_id)
.first()
async with a_session_maker() as session:
result = await session.execute(
select(JiraUser).filter(JiraUser.keycloak_user_id == keycloak_user_id)
)
jira_user = result.scalars().first()
if not jira_user:
raise ValueError(
@@ -187,60 +191,61 @@ class JiraIntegrationStore:
)
jira_user.status = status
session.commit()
session.refresh(jira_user)
await session.commit()
await session.refresh(jira_user)
logger.info(f'[Jira] Updated user {keycloak_user_id} status to {status}')
return jira_user
async def deactivate_workspace(self, workspace_id: int):
"""Deactivate the workspace and all user links for a given workspace."""
with session_maker() as session:
users = (
session.query(JiraUser)
.filter(
JiraUser.jira_workspace_id == workspace_id,
JiraUser.status == 'active',
async with a_session_maker() as session:
result = await session.execute(
select(JiraUser).filter(
and_(
JiraUser.jira_workspace_id == workspace_id,
JiraUser.status == 'active',
)
)
.all()
)
users = result.scalars().all()
for user in users:
user.status = 'inactive'
session.add(user)
workspace = (
session.query(JiraWorkspace)
.filter(JiraWorkspace.id == workspace_id)
.first()
result = await session.execute(
select(JiraWorkspace).filter(JiraWorkspace.id == workspace_id)
)
workspace = result.scalars().first()
if workspace:
workspace.status = 'inactive'
session.add(workspace)
session.commit()
await session.commit()
logger.info(f'[Jira] Deactivated all user links for workspace {workspace_id}')
async def create_conversation(self, jira_conversation: JiraConversation) -> None:
"""Create a new Jira conversation record."""
with session_maker() as session:
async with a_session_maker() as session:
session.add(jira_conversation)
session.commit()
await session.commit()
async def get_user_conversations_by_issue_id(
self, issue_id: str, jira_user_id: int
) -> JiraConversation | None:
"""Get a Jira conversation by issue ID and jira user ID."""
with session_maker() as session:
return (
session.query(JiraConversation)
.filter(
JiraConversation.issue_id == issue_id,
JiraConversation.jira_user_id == jira_user_id,
async with a_session_maker() as session:
result = await session.execute(
select(JiraConversation).filter(
and_(
JiraConversation.issue_id == issue_id,
JiraConversation.jira_user_id == jira_user_id,
)
)
.first()
)
return result.scalars().first()
@classmethod
def get_instance(cls) -> JiraIntegrationStore:

View File

@@ -3,7 +3,8 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from storage.database import session_maker
from sqlalchemy import select
from storage.database import a_session_maker
from storage.linear_conversation import LinearConversation
from storage.linear_user import LinearUser
from storage.linear_workspace import LinearWorkspace
@@ -35,10 +36,10 @@ class LinearIntegrationStore:
status=status,
)
with session_maker() as session:
async with a_session_maker() as session:
session.add(workspace)
session.commit()
session.refresh(workspace)
await session.commit()
await session.refresh(workspace)
logger.info(f'[Linear] Created workspace {workspace.name}')
return workspace
@@ -53,11 +54,12 @@ class LinearIntegrationStore:
status: Optional[str] = None,
) -> LinearWorkspace:
"""Update an existing Linear workspace with encrypted sensitive data."""
with session_maker() as session:
async with a_session_maker() as session:
# Find existing workspace by ID
workspace = (
session.query(LinearWorkspace).filter(LinearWorkspace.id == id).first()
result = await session.execute(
select(LinearWorkspace).where(LinearWorkspace.id == id)
)
workspace = result.scalar_one_or_none()
if not workspace:
raise ValueError(f'Workspace with ID "{id}" not found')
@@ -77,8 +79,8 @@ class LinearIntegrationStore:
if status is not None:
workspace.status = status
session.commit()
session.refresh(workspace)
await session.commit()
await session.refresh(workspace)
logger.info(f'[Linear] Updated workspace {workspace.name}')
return workspace
@@ -98,10 +100,10 @@ class LinearIntegrationStore:
status=status,
)
with session_maker() as session:
async with a_session_maker() as session:
session.add(linear_user)
session.commit()
session.refresh(linear_user)
await session.commit()
await session.refresh(linear_user)
logger.info(
f'[Linear] Created user {linear_user.id} for workspace {linear_workspace_id}'
@@ -110,77 +112,75 @@ class LinearIntegrationStore:
async def get_workspace_by_id(self, workspace_id: int) -> Optional[LinearWorkspace]:
"""Retrieve workspace by ID."""
with session_maker() as session:
return (
session.query(LinearWorkspace)
.filter(LinearWorkspace.id == workspace_id)
.first()
async with a_session_maker() as session:
result = await session.execute(
select(LinearWorkspace).where(LinearWorkspace.id == workspace_id)
)
return result.scalar_one_or_none()
async def get_workspace_by_name(
self, workspace_name: str
) -> Optional[LinearWorkspace]:
"""Retrieve workspace by name."""
with session_maker() as session:
return (
session.query(LinearWorkspace)
.filter(LinearWorkspace.name == workspace_name.lower())
.first()
async with a_session_maker() as session:
result = await session.execute(
select(LinearWorkspace).where(
LinearWorkspace.name == workspace_name.lower()
)
)
return result.scalar_one_or_none()
async def get_user_by_active_workspace(
self, keycloak_user_id: str
) -> LinearUser | None:
"""Get Linear user by Keycloak user ID."""
with session_maker() as session:
return (
session.query(LinearUser)
.filter(
async with a_session_maker() as session:
result = await session.execute(
select(LinearUser).where(
LinearUser.keycloak_user_id == keycloak_user_id,
LinearUser.status == 'active',
)
.first()
)
return result.scalar_one_or_none()
async def get_user_by_keycloak_id_and_workspace(
self, keycloak_user_id: str, linear_workspace_id: int
) -> Optional[LinearUser]:
"""Get Linear user by Keycloak user ID and workspace ID."""
with session_maker() as session:
return (
session.query(LinearUser)
.filter(
async with a_session_maker() as session:
result = await session.execute(
select(LinearUser).where(
LinearUser.keycloak_user_id == keycloak_user_id,
LinearUser.linear_workspace_id == linear_workspace_id,
)
.first()
)
return result.scalar_one_or_none()
async def get_active_user(
self, linear_user_id: str, linear_workspace_id: int
) -> Optional[LinearUser]:
"""Get Linear user by Keycloak user ID and workspace ID."""
with session_maker() as session:
return (
session.query(LinearUser)
.filter(
async with a_session_maker() as session:
result = await session.execute(
select(LinearUser).where(
LinearUser.linear_user_id == linear_user_id,
LinearUser.linear_workspace_id == linear_workspace_id,
LinearUser.status == 'active',
)
.first()
)
return result.scalar_one_or_none()
async def update_user_integration_status(
self, keycloak_user_id: str, status: str
) -> LinearUser:
"""Update Linear user integration status."""
with session_maker() as session:
linear_user = (
session.query(LinearUser)
.filter(LinearUser.keycloak_user_id == keycloak_user_id)
.first()
async with a_session_maker() as session:
result = await session.execute(
select(LinearUser).where(
LinearUser.keycloak_user_id == keycloak_user_id
)
)
linear_user = result.scalar_one_or_none()
if not linear_user:
raise ValueError(
@@ -188,38 +188,36 @@ class LinearIntegrationStore:
)
linear_user.status = status
session.commit()
session.refresh(linear_user)
await session.commit()
await session.refresh(linear_user)
logger.info(f'[Linear] Updated user {keycloak_user_id} status to {status}')
return linear_user
async def deactivate_workspace(self, workspace_id: int):
"""Deactivate the workspace and all user links for a given workspace."""
with session_maker() as session:
users = (
session.query(LinearUser)
.filter(
async with a_session_maker() as session:
result = await session.execute(
select(LinearUser).where(
LinearUser.linear_workspace_id == workspace_id,
LinearUser.status == 'active',
)
.all()
)
users = result.scalars().all()
for user in users:
user.status = 'inactive'
session.add(user)
workspace = (
session.query(LinearWorkspace)
.filter(LinearWorkspace.id == workspace_id)
.first()
result = await session.execute(
select(LinearWorkspace).where(LinearWorkspace.id == workspace_id)
)
workspace = result.scalar_one_or_none()
if workspace:
workspace.status = 'inactive'
session.add(workspace)
session.commit()
await session.commit()
logger.info(f'[Jira] Deactivated all user links for workspace {workspace_id}')
@@ -227,23 +225,22 @@ class LinearIntegrationStore:
self, linear_conversation: LinearConversation
) -> None:
"""Create a new Linear conversation record."""
with session_maker() as session:
async with a_session_maker() as session:
session.add(linear_conversation)
session.commit()
await session.commit()
async def get_user_conversations_by_issue_id(
self, issue_id: str, linear_user_id: int
) -> LinearConversation | None:
"""Get a Linear conversation by issue ID and linear user ID."""
with session_maker() as session:
return (
session.query(LinearConversation)
.filter(
async with a_session_maker() as session:
result = await session.execute(
select(LinearConversation).where(
LinearConversation.issue_id == issue_id,
LinearConversation.linear_user_id == linear_user_id,
)
.first()
)
return result.scalar_one_or_none()
@classmethod
def get_instance(cls) -> LinearIntegrationStore:

View File

@@ -2,8 +2,8 @@ from __future__ import annotations
from dataclasses import dataclass
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from sqlalchemy import select
from storage.database import a_session_maker
from storage.stored_offline_token import StoredOfflineToken
from openhands.core.config.openhands_config import OpenHandsConfig
@@ -13,17 +13,17 @@ from openhands.core.logger import openhands_logger as logger
@dataclass
class OfflineTokenStore:
user_id: str
session_maker: sessionmaker
config: OpenHandsConfig
async def store_token(self, offline_token: str) -> None:
"""Store an offline token in the database."""
with self.session_maker() as session:
token_record = (
session.query(StoredOfflineToken)
.filter(StoredOfflineToken.user_id == self.user_id)
.first()
async with a_session_maker() as session:
result = await session.execute(
select(StoredOfflineToken).where(
StoredOfflineToken.user_id == self.user_id
)
)
token_record = result.scalar_one_or_none()
if token_record:
token_record.offline_token = offline_token
@@ -32,16 +32,17 @@ class OfflineTokenStore:
user_id=self.user_id, offline_token=offline_token
)
session.add(token_record)
session.commit()
await session.commit()
async def load_token(self) -> str | None:
"""Load an offline token from the database."""
with self.session_maker() as session:
token_record = (
session.query(StoredOfflineToken)
.filter(StoredOfflineToken.user_id == self.user_id)
.first()
async with a_session_maker() as session:
result = await session.execute(
select(StoredOfflineToken).where(
StoredOfflineToken.user_id == self.user_id
)
)
token_record = result.scalar_one_or_none()
if not token_record:
return None
@@ -56,4 +57,4 @@ class OfflineTokenStore:
logger.debug(f'offline_token_store.get_instance::{user_id}')
if user_id:
user_id = str(user_id)
return OfflineTokenStore(user_id, session_maker, config)
return OfflineTokenStore(user_id, config)

View File

@@ -0,0 +1,105 @@
"""Store class for managing organization app settings."""
from __future__ import annotations
from dataclasses import dataclass
from uuid import UUID
from server.constants import (
LITE_LLM_API_URL,
ORG_SETTINGS_VERSION,
get_default_litellm_model,
)
from server.routes.org_models import OrgAppSettingsUpdate
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from storage.org import Org
from storage.user import User
@dataclass
class OrgAppSettingsStore:
"""Store for organization app settings with injected db_session."""
db_session: AsyncSession
async def get_current_org_by_user_id(self, user_id: str) -> Org | None:
"""Get the current organization for a user.
Args:
user_id: The user's ID (Keycloak user ID)
Returns:
Org: The organization object, or None if not found
"""
# Get user with their current_org_id
result = await self.db_session.execute(
select(User).filter(User.id == UUID(user_id))
)
user = result.scalars().first()
if not user:
return None
org_id = user.current_org_id
if not org_id:
return None
# Get the organization
result = await self.db_session.execute(select(Org).filter(Org.id == org_id))
org = result.scalars().first()
if not org:
return None
return await self._validate_org_version(org)
async def _validate_org_version(self, org: Org) -> Org:
"""Check if we need to update org version.
Args:
org: The organization to validate
Returns:
Org: The validated (and potentially updated) organization
"""
if org.org_version < ORG_SETTINGS_VERSION:
org.org_version = ORG_SETTINGS_VERSION
org.default_llm_model = get_default_litellm_model()
org.llm_base_url = LITE_LLM_API_URL
await self.db_session.flush()
await self.db_session.refresh(org)
return org
async def update_org_app_settings(
self, org_id: UUID, update_data: OrgAppSettingsUpdate
) -> Org | None:
"""Update organization app settings.
Only updates fields that are explicitly provided in update_data.
Uses flush() - commit happens at request end via DbSessionInjector.
Args:
org_id: The organization's ID
update_data: Pydantic model with fields to update
Returns:
Org: The updated organization object, or None if not found
"""
result = await self.db_session.execute(
select(Org).filter(Org.id == org_id).with_for_update()
)
org = result.scalars().first()
if not org:
return None
# Update only explicitly provided fields
for field, value in update_data.model_dump(exclude_unset=True).items():
setattr(org, field, value)
# flush instead of commit - DbSessionInjector auto-commits at request end
await self.db_session.flush()
await self.db_session.refresh(org)
return org

View File

@@ -0,0 +1,83 @@
"""Store class for managing organization LLM settings."""
from __future__ import annotations
import uuid
from dataclasses import dataclass
from uuid import UUID
from server.routes.org_models import OrgLLMSettingsUpdate
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from storage.org import Org
from storage.org_member_store import OrgMemberStore
from storage.user import User
@dataclass
class OrgLLMSettingsStore:
"""Store for org LLM settings with injected db_session."""
db_session: AsyncSession
async def get_current_org_by_user_id(self, user_id: str) -> Org | None:
"""Get the user's current organization.
Args:
user_id: The user's ID (Keycloak user ID)
Returns:
Org: The user's current organization, or None if not found
"""
# First get the user to find their current_org_id
result = await self.db_session.execute(
select(User).filter(User.id == uuid.UUID(user_id))
)
user = result.scalars().first()
if not user or not user.current_org_id:
return None
# Then get the org
result = await self.db_session.execute(
select(Org).filter(Org.id == user.current_org_id)
)
return result.scalars().first()
async def update_org_llm_settings(
self, org_id: UUID, update_data: OrgLLMSettingsUpdate
) -> Org | None:
"""Update organization LLM settings.
Also propagates relevant settings to all org members.
Uses flush() - commit happens at request end via DbSessionInjector.
Args:
org_id: The organization's ID
update_data: Pydantic model with fields to update
Returns:
Org: The updated organization, or None if org not found
"""
result = await self.db_session.execute(
select(Org).filter(Org.id == org_id).with_for_update()
)
org = result.scalars().first()
if not org:
return None
# Apply updates to org (excludes llm_api_key which is member-only)
update_data.apply_to_org(org)
# Propagate relevant settings to all org members
member_updates = update_data.get_member_updates()
if member_updates:
await OrgMemberStore.update_all_members_llm_settings_async(
self.db_session, org_id, member_updates
)
# flush instead of commit - DbSessionInjector auto-commits at request end
await self.db_session.flush()
await self.db_session.refresh(org)
return org

View File

@@ -5,9 +5,12 @@ Store class for managing organization-member relationships.
from typing import Optional
from uuid import UUID
from sqlalchemy import func, select
from server.routes.org_models import OrgMemberLLMSettings
from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker, session_maker
from storage.encrypt_utils import encrypt_value
from storage.org_member import OrgMember
from storage.user import User
from storage.user_settings import UserSettings
@@ -254,3 +257,28 @@ class OrgMemberStore:
members = members[:limit]
return members, has_more
@staticmethod
async def update_all_members_llm_settings_async(
session: AsyncSession,
org_id: UUID,
member_settings: OrgMemberLLMSettings,
) -> None:
"""Update LLM settings for all members of an organization.
Args:
session: Database session (passed from caller for transaction)
org_id: Organization ID
member_settings: Typed LLM settings to apply to all members
"""
# Build update values from non-None fields
values = member_settings.model_dump(exclude_none=True)
# Handle encrypted llm_api_key field - map to _llm_api_key column with encryption
if 'llm_api_key' in values:
raw_key = values.pop('llm_api_key')
values['_llm_api_key'] = encrypt_value(raw_key)
if values:
stmt = update(OrgMember).where(OrgMember.org_id == org_id).values(**values)
await session.execute(stmt)

View File

@@ -3,6 +3,7 @@ Service class for managing organization operations.
Separates business logic from route handlers.
"""
from typing import NoReturn
from uuid import UUID, uuid4
from uuid import UUID as parse_uuid
@@ -325,7 +326,7 @@ class OrgService:
user_id: str,
original_error: Exception,
error_message: str,
) -> None:
) -> NoReturn:
"""
Handle failure by cleaning up LiteLLM resources and raising appropriate error.

View File

@@ -10,10 +10,10 @@ from server.constants import (
ORG_SETTINGS_VERSION,
get_default_litellm_model,
)
from server.routes.org_models import OrphanedUserError
from sqlalchemy import text
from server.routes.org_models import OrgLLMSettingsUpdate, OrphanedUserError
from sqlalchemy import select, text
from sqlalchemy.orm import joinedload
from storage.database import session_maker
from storage.database import a_session_maker, session_maker
from storage.lite_llm_manager import LiteLlmManager
from storage.org import Org
from storage.org_member import OrgMember
@@ -386,3 +386,47 @@ class OrgStore:
extra={'org_id': str(org_id), 'error': str(e)},
)
raise
@staticmethod
async def get_org_by_id_async(org_id: UUID) -> Org | None:
"""Get organization by ID (async version)."""
async with a_session_maker() as session:
result = await session.execute(select(Org).filter(Org.id == org_id))
org = result.scalars().first()
return OrgStore._validate_org_version(org) if org else None
@staticmethod
async def update_org_llm_settings_async(
org_id: UUID,
llm_settings: OrgLLMSettingsUpdate,
) -> Org | None:
"""Update organization LLM settings and propagate to members (async version).
Args:
org_id: Organization ID
llm_settings: Typed LLM settings update model
Returns:
Updated Org or None if not found
"""
from storage.org_member_store import OrgMemberStore
async with a_session_maker() as session:
result = await session.execute(select(Org).filter(Org.id == org_id))
org = result.scalars().first()
if not org:
return None
# Apply updates to org
llm_settings.apply_to_org(org)
# Propagate relevant settings to all org members
member_updates = llm_settings.get_member_updates()
if member_updates:
await OrgMemberStore.update_all_members_llm_settings_async(
session, org_id, member_updates
)
await session.commit()
await session.refresh(org)
return org

View File

@@ -10,7 +10,6 @@ from integrations.github.github_types import (
WorkflowRunStatus,
)
from sqlalchemy import and_, delete, select, update
from sqlalchemy.orm import sessionmaker
from storage.database import a_session_maker
from storage.proactive_convos import ProactiveConversation
@@ -20,8 +19,6 @@ from openhands.integrations.service_types import ProviderType
@dataclass
class ProactiveConversationStore:
a_session_maker: sessionmaker = a_session_maker
def get_repo_id(self, provider: ProviderType, repo_id):
return f'{provider.value}##{repo_id}'
@@ -51,7 +48,7 @@ class ProactiveConversationStore:
final_workflow_group = None
async with self.a_session_maker() as session:
async with a_session_maker() as session:
# Start an explicit transaction with row-level locking
async with session.begin():
# Get the existing proactive conversation entry with FOR UPDATE lock
@@ -142,7 +139,7 @@ class ProactiveConversationStore:
# Calculate the cutoff time (current time - older_than_minutes)
cutoff_time = datetime.now(UTC) - timedelta(minutes=older_than_minutes)
async with self.a_session_maker() as session:
async with a_session_maker() as session:
async with session.begin():
# Delete records older than the cutoff time
delete_stmt = delete(ProactiveConversation).where(
@@ -158,9 +155,9 @@ class ProactiveConversationStore:
@classmethod
async def get_instance(cls) -> ProactiveConversationStore:
"""Get an instance of the GitlabWebhookStore.
"""Get an instance of the ProactiveConversationStore.
Returns:
An instance of GitlabWebhookStore
An instance of ProactiveConversationStore
"""
return ProactiveConversationStore(a_session_maker)
return ProactiveConversationStore()

View File

@@ -2,8 +2,8 @@ from __future__ import annotations
from dataclasses import dataclass
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from sqlalchemy import select
from storage.database import a_session_maker
from storage.stored_repository import StoredRepository
from openhands.core.config.openhands_config import OpenHandsConfig
@@ -11,12 +11,11 @@ from openhands.core.config.openhands_config import OpenHandsConfig
@dataclass
class RepositoryStore:
session_maker: sessionmaker
config: OpenHandsConfig
def store_projects(self, repositories: list[StoredRepository]) -> None:
async def store_projects(self, repositories: list[StoredRepository]) -> None:
"""
Store repositories in database
Store repositories in database (async version)
1. Make sure to store repositories if its ID doesn't exist
2. If repository ID already exists, make sure to only update the repo is_public and repo_name fields
@@ -26,17 +25,15 @@ class RepositoryStore:
if not repositories:
return
with self.session_maker() as session:
async with a_session_maker() as session:
# Extract all repo_ids to check
repo_ids = [r.repo_id for r in repositories]
# Get all existing repositories in a single query
existing_repos = {
r.repo_id: r
for r in session.query(StoredRepository).filter(
StoredRepository.repo_id.in_(repo_ids)
)
}
result = await session.execute(
select(StoredRepository).filter(StoredRepository.repo_id.in_(repo_ids))
)
existing_repos = {r.repo_id: r for r in result.scalars().all()}
# Process all repositories
for repo in repositories:
@@ -50,9 +47,9 @@ class RepositoryStore:
session.add(repo)
# Commit all changes
session.commit()
await session.commit()
@classmethod
def get_instance(cls, config: OpenHandsConfig) -> RepositoryStore:
"""Get an instance of the UserRepositoryStore."""
return RepositoryStore(session_maker, config)
return RepositoryStore(config)

View File

@@ -234,6 +234,8 @@ class SaasConversationStore(ConversationStore):
cls, config: OpenHandsConfig, user_id: str | None
) -> ConversationStore:
# user_id should not be None in SaaS, should we raise?
# Use async version since callers now use asyncio.run_coroutine_threadsafe()
# to dispatch to the main event loop where asyncpg connections work properly.
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id if user else None
return SaasConversationStore(str(user_id), org_id, session_maker)

View File

@@ -28,7 +28,7 @@ class SaasConversationValidator(ConversationValidator):
# Validate the API key and get the user_id
api_key_store = ApiKeyStore.get_instance()
user_id = api_key_store.validate_api_key(api_key)
user_id = await api_key_store.validate_api_key(api_key)
if not user_id:
logger.warning('Invalid API key')

View File

@@ -5,8 +5,8 @@ from base64 import b64decode, b64encode
from dataclasses import dataclass
from cryptography.fernet import Fernet
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from sqlalchemy import delete, select
from storage.database import a_session_maker
from storage.stored_custom_secrets import StoredCustomSecrets
from storage.user_store import UserStore
@@ -19,7 +19,6 @@ from openhands.storage.secrets.secrets_store import SecretsStore
@dataclass
class SaasSecretsStore(SecretsStore):
user_id: str
session_maker: sessionmaker
config: OpenHandsConfig
async def load(self) -> Secrets | None:
@@ -28,14 +27,15 @@ class SaasSecretsStore(SecretsStore):
user = await UserStore.get_user_by_id_async(self.user_id)
org_id = user.current_org_id if user else None
with self.session_maker() as session:
async with a_session_maker() as session:
# Fetch all secrets for the given user ID
query = session.query(StoredCustomSecrets).filter(
query = select(StoredCustomSecrets).filter(
StoredCustomSecrets.keycloak_user_id == self.user_id
)
if org_id is not None:
query = query.filter(StoredCustomSecrets.org_id == org_id)
settings = query.all()
result = await session.execute(query)
settings = result.scalars().all()
if not settings:
return Secrets()
@@ -54,12 +54,15 @@ class SaasSecretsStore(SecretsStore):
async def store(self, item: Secrets):
user = await UserStore.get_user_by_id_async(self.user_id)
org_id = user.current_org_id
with self.session_maker() as session:
async with a_session_maker() as session:
# Incoming secrets are always the most updated ones
# Delete all existing records and override with incoming ones
session.query(StoredCustomSecrets).filter(
StoredCustomSecrets.keycloak_user_id == self.user_id
).delete()
await session.execute(
delete(StoredCustomSecrets).filter(
StoredCustomSecrets.keycloak_user_id == self.user_id
)
)
# Prepare the new secrets data
kwargs = item.model_dump(context={'expose_secrets': True})
@@ -89,7 +92,7 @@ class SaasSecretsStore(SecretsStore):
)
session.add(new_secret)
session.commit()
await session.commit()
def _decrypt_kwargs(self, kwargs: dict):
fernet = self._fernet()
@@ -133,4 +136,4 @@ class SaasSecretsStore(SecretsStore):
if not user_id:
raise Exception('SaasSecretsStore cannot be constructed with no user_id')
logger.debug(f'saas_secrets_store.get_instance::{user_id}')
return SaasSecretsStore(user_id, session_maker, config)
return SaasSecretsStore(user_id, config)

View File

@@ -10,8 +10,9 @@ from cryptography.fernet import Fernet
from pydantic import SecretStr
from server.constants import LITE_LLM_API_URL
from server.logger import logger
from sqlalchemy.orm import joinedload, sessionmaker
from storage.database import session_maker
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker
from storage.lite_llm_manager import LiteLlmManager, get_openhands_cloud_key_alias
from storage.org import Org
from storage.org_member import OrgMember
@@ -23,26 +24,24 @@ from storage.user_store import UserStore
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.server.settings import Settings
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.llm import is_openhands_model
@dataclass
class SaasSettingsStore(SettingsStore):
user_id: str
session_maker: sessionmaker
config: OpenHandsConfig
ENCRYPT_VALUES = ['llm_api_key', 'llm_api_key_for_byor', 'search_api_key']
def _get_user_settings_by_keycloak_id(
async def _get_user_settings_by_keycloak_id_async(
self, keycloak_user_id: str, session=None
) -> UserSettings | None:
"""
Get UserSettings by keycloak_user_id.
Get UserSettings by keycloak_user_id (async version).
Args:
keycloak_user_id: The keycloak user ID to search for
session: Optional existing database session. If not provided, creates a new one.
session: Optional existing async database session. If not provided, creates a new one.
Returns:
UserSettings object if found, None otherwise
@@ -50,27 +49,26 @@ class SaasSettingsStore(SettingsStore):
if not keycloak_user_id:
return None
def _get_settings():
if session:
# Use provided session
return (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == keycloak_user_id)
.first()
if session:
# Use provided session
result = await session.execute(
select(UserSettings).filter(
UserSettings.keycloak_user_id == keycloak_user_id
)
else:
# Create new session
with self.session_maker() as new_session:
return (
new_session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == keycloak_user_id)
.first()
)
return result.scalars().first()
else:
# Create new session
async with a_session_maker() as new_session:
result = await new_session.execute(
select(UserSettings).filter(
UserSettings.keycloak_user_id == keycloak_user_id
)
return _get_settings()
)
return result.scalars().first()
async def load(self) -> Settings | None:
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
user = await UserStore.get_user_by_id_async(self.user_id)
if not user:
logger.error(f'User not found for ID {self.user_id}')
return None
@@ -83,7 +81,7 @@ class SaasSettingsStore(SettingsStore):
break
if not org_member or not org_member.llm_api_key:
return None
org = OrgStore.get_org_by_id(org_id)
org = await OrgStore.get_org_by_id_async(org_id)
if not org:
logger.error(
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
@@ -122,21 +120,22 @@ class SaasSettingsStore(SettingsStore):
return settings
async def store(self, item: Settings):
with self.session_maker() as session:
async with a_session_maker() as session:
if not item:
return None
user = (
session.query(User)
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(self.user_id))
).first()
)
user = result.scalars().first()
if not user:
# Check if we need to migrate from user_settings
user_settings = None
with session_maker() as session:
user_settings = self._get_user_settings_by_keycloak_id(
self.user_id, session
async with a_session_maker() as new_session:
user_settings = await self._get_user_settings_by_keycloak_id_async(
self.user_id, new_session
)
if user_settings:
user = await UserStore.migrate_user(self.user_id, user_settings)
@@ -154,7 +153,8 @@ class SaasSettingsStore(SettingsStore):
if not org_member or not org_member.llm_api_key:
return None
org: Org = session.query(Org).filter(Org.id == org_id).first()
result = await session.execute(select(Org).filter(Org.id == org_id))
org = result.scalars().first()
if not org:
logger.error(
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
@@ -173,7 +173,7 @@ class SaasSettingsStore(SettingsStore):
if hasattr(model, key):
setattr(model, key, value)
session.commit()
await session.commit()
@classmethod
async def get_instance(
@@ -182,7 +182,7 @@ class SaasSettingsStore(SettingsStore):
user_id: str, # type: ignore[override]
) -> SaasSettingsStore:
logger.debug(f'saas_settings_store.get_instance::{user_id}')
return SaasSettingsStore(user_id, session_maker, config)
return SaasSettingsStore(user_id, config)
def _should_encrypt(self, key):
return key in self.ENCRYPT_VALUES

View File

@@ -2,38 +2,35 @@ from __future__ import annotations
from dataclasses import dataclass
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from sqlalchemy import select
from storage.database import a_session_maker
from storage.slack_conversation import SlackConversation
@dataclass
class SlackConversationStore:
session_maker: sessionmaker
async def get_slack_conversation(
self, channel_id: str, parent_id: str
) -> SlackConversation | None:
"""Get a slack conversation by channel_id and message_ts.
Both parameters are required to match for a conversation to be returned.
"""
with session_maker() as session:
conversation = (
session.query(SlackConversation)
.filter(SlackConversation.channel_id == channel_id)
.filter(SlackConversation.parent_id == parent_id)
.first()
async with a_session_maker() as session:
result = await session.execute(
select(SlackConversation).where(
SlackConversation.channel_id == channel_id,
SlackConversation.parent_id == parent_id,
)
)
return conversation
return result.scalar_one_or_none()
async def create_slack_conversation(
self, slack_converstion: SlackConversation
) -> None:
with self.session_maker() as session:
async with a_session_maker() as session:
session.merge(slack_converstion)
session.commit()
await session.commit()
@classmethod
def get_instance(cls) -> SlackConversationStore:
return SlackConversationStore(session_maker)
return SlackConversationStore()

View File

@@ -32,6 +32,7 @@ class SlackTeamStore:
# Store the token
session.add(slack_team)
session.commit()
return slack_team
@classmethod
def get_instance(cls):

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
from dataclasses import dataclass
import sqlalchemy
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from sqlalchemy import select
from storage.database import a_session_maker
from storage.user_repo_map import UserRepositoryMap
from openhands.core.config.openhands_config import OpenHandsConfig
@@ -12,12 +12,11 @@ from openhands.core.config.openhands_config import OpenHandsConfig
@dataclass
class UserRepositoryMapStore:
session_maker: sessionmaker
config: OpenHandsConfig
def store_user_repo_mappings(self, mappings: list[UserRepositoryMap]) -> None:
async def store_user_repo_mappings(self, mappings: list[UserRepositoryMap]) -> None:
"""
Store user-repository mappings in database
Store user-repository mappings in database (async version)
1. Make sure to store mappings if they don't exist
2. If a mapping already exists (same user_id and repo_id), update the admin field
@@ -30,18 +29,20 @@ class UserRepositoryMapStore:
if not mappings:
return
with self.session_maker() as session:
async with a_session_maker() as session:
# Extract all user_id/repo_id pairs to check
mapping_keys = [(m.user_id, m.repo_id) for m in mappings]
# Get all existing mappings in a single query
existing_mappings = {
(m.user_id, m.repo_id): m
for m in session.query(UserRepositoryMap).filter(
result = await session.execute(
select(UserRepositoryMap).filter(
sqlalchemy.tuple_(
UserRepositoryMap.user_id, UserRepositoryMap.repo_id
).in_(mapping_keys)
)
)
existing_mappings = {
(m.user_id, m.repo_id): m for m in result.scalars().all()
}
# Process all mappings
@@ -56,9 +57,9 @@ class UserRepositoryMapStore:
session.add(mapping)
# Commit all changes
session.commit()
await session.commit()
@classmethod
def get_instance(cls, config: OpenHandsConfig) -> UserRepositoryMapStore:
"""Get an instance of the UserRepositoryMapStore."""
return UserRepositoryMapStore(session_maker, config)
return UserRepositoryMapStore(config)

View File

@@ -227,7 +227,7 @@ class UserStore:
'user_store:migrate_user:calling_stripe_migrate_customer',
extra={'user_id': user_id},
)
await migrate_customer(session, user_id, org)
await migrate_customer(user_id, org)
logger.debug(
'user_store:migrate_user:done_stripe_migrate_customer',
extra={'user_id': user_id},

View File

@@ -1,39 +0,0 @@
"""SQLAlchemy model for verified LLM models."""
from sqlalchemy import (
Boolean,
Column,
DateTime,
Identity,
Integer,
String,
UniqueConstraint,
func,
text,
)
from storage.base import Base
class VerifiedModel(Base): # type: ignore
"""A verified LLM model available in the model selector.
The composite unique constraint on (model_name, provider) allows the same
model name to exist under different providers (e.g. 'claude-sonnet' under
both 'openhands' and 'anthropic').
"""
__tablename__ = 'verified_models'
__table_args__ = (
UniqueConstraint('model_name', 'provider', name='uq_verified_model_provider'),
)
id = Column(Integer, Identity(), primary_key=True)
model_name = Column(String(255), nullable=False)
provider = Column(String(100), nullable=False, index=True)
is_enabled = Column(
Boolean, nullable=False, default=True, server_default=text('true')
)
created_at = Column(DateTime, nullable=False, server_default=func.now())
updated_at = Column(
DateTime, nullable=False, server_default=func.now(), onupdate=func.now()
)

View File

@@ -1,187 +0,0 @@
"""Store for managing verified LLM models in the database."""
from sqlalchemy import and_
from storage.database import session_maker
from storage.verified_model import VerifiedModel
from openhands.core.logger import openhands_logger as logger
class VerifiedModelStore:
"""Store for CRUD operations on verified models.
Follows the project convention of static methods with session_maker()
(see UserStore, OrgMemberStore for reference).
"""
@staticmethod
def get_enabled_models() -> list[VerifiedModel]:
"""Get all enabled models.
Returns:
list[VerifiedModel]: All models where is_enabled is True
"""
with session_maker() as session:
return (
session.query(VerifiedModel)
.filter(VerifiedModel.is_enabled.is_(True))
.order_by(VerifiedModel.provider, VerifiedModel.model_name)
.all()
)
@staticmethod
def get_models_by_provider(provider: str) -> list[VerifiedModel]:
"""Get all enabled models for a specific provider.
Args:
provider: The provider name (e.g., 'openhands', 'anthropic')
"""
with session_maker() as session:
return (
session.query(VerifiedModel)
.filter(
and_(
VerifiedModel.provider == provider,
VerifiedModel.is_enabled.is_(True),
)
)
.order_by(VerifiedModel.model_name)
.all()
)
@staticmethod
def get_all_models() -> list[VerifiedModel]:
"""Get all models (including disabled)."""
with session_maker() as session:
return (
session.query(VerifiedModel)
.order_by(VerifiedModel.provider, VerifiedModel.model_name)
.all()
)
@staticmethod
def get_model(model_name: str, provider: str) -> VerifiedModel | None:
"""Get a model by its composite key (model_name, provider).
Args:
model_name: The model identifier
provider: The provider name
"""
with session_maker() as session:
return (
session.query(VerifiedModel)
.filter(
and_(
VerifiedModel.model_name == model_name,
VerifiedModel.provider == provider,
)
)
.first()
)
@staticmethod
def create_model(
model_name: str, provider: str, is_enabled: bool = True
) -> VerifiedModel:
"""Create a new verified model.
Args:
model_name: The model identifier
provider: The provider name
is_enabled: Whether the model is enabled (default True)
Raises:
ValueError: If a model with the same (model_name, provider) already exists
"""
with session_maker() as session:
existing = (
session.query(VerifiedModel)
.filter(
and_(
VerifiedModel.model_name == model_name,
VerifiedModel.provider == provider,
)
)
.first()
)
if existing:
raise ValueError(f'Model {provider}/{model_name} already exists')
model = VerifiedModel(
model_name=model_name,
provider=provider,
is_enabled=is_enabled,
)
session.add(model)
session.commit()
session.refresh(model)
logger.info(f'Created verified model: {provider}/{model_name}')
return model
@staticmethod
def update_model(
model_name: str,
provider: str,
is_enabled: bool | None = None,
) -> VerifiedModel | None:
"""Update an existing verified model.
Args:
model_name: The model name to update
provider: The provider name
is_enabled: New enabled state (optional)
Returns:
The updated model if found, None otherwise
"""
with session_maker() as session:
model = (
session.query(VerifiedModel)
.filter(
and_(
VerifiedModel.model_name == model_name,
VerifiedModel.provider == provider,
)
)
.first()
)
if not model:
return None
if is_enabled is not None:
model.is_enabled = is_enabled
session.commit()
session.refresh(model)
logger.info(f'Updated verified model: {provider}/{model_name}')
return model
@staticmethod
def delete_model(model_name: str, provider: str) -> bool:
"""Delete a verified model.
Args:
model_name: The model name to delete
provider: The provider name
Returns:
True if deleted, False if not found
"""
with session_maker() as session:
model = (
session.query(VerifiedModel)
.filter(
and_(
VerifiedModel.model_name == model_name,
VerifiedModel.provider == provider,
)
)
.first()
)
if not model:
return False
session.delete(model)
session.commit()
logger.info(f'Deleted verified model: {provider}/{model_name}')
return True

View File

@@ -4,11 +4,20 @@ from uuid import UUID
import pytest
from server.constants import ORG_SETTINGS_VERSION
from server.verified_models.verified_model_service import (
StoredVerifiedModel, # noqa: F401
)
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import sessionmaker
from storage.base import Base
# Anything not loaded here may not have a table created for it.
from storage.api_key import ApiKey # noqa: F401
from storage.base import Base
from storage.billing_session import BillingSession
from storage.conversation_work import ConversationWork
from storage.device_code import DeviceCode # noqa: F401
@@ -25,12 +34,20 @@ from storage.stored_conversation_metadata_saas import (
from storage.stored_offline_token import StoredOfflineToken
from storage.stripe_customer import StripeCustomer
from storage.user import User
from storage.verified_model import VerifiedModel # noqa: F401
@pytest.fixture(scope='function')
def db_path(tmp_path):
"""Create a unique temp file path for each test."""
return str(tmp_path / 'test.db')
@pytest.fixture
def engine():
engine = create_engine('sqlite:///:memory:')
def engine(db_path):
"""Create a sync engine with tables using file-based DB."""
engine = create_engine(
f'sqlite:///{db_path}', connect_args={'check_same_thread': False}
)
Base.metadata.create_all(engine)
return engine
@@ -40,6 +57,36 @@ def session_maker(engine):
return sessionmaker(bind=engine)
@pytest.fixture
def async_engine(db_path):
"""Create an async engine using the SAME file-based database."""
async_engine = create_async_engine(
f'sqlite+aiosqlite:///{db_path}',
connect_args={'check_same_thread': False},
)
async def create_tables():
async with async_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Run the async function synchronously
import asyncio
asyncio.run(create_tables())
return async_engine
@pytest.fixture
async def async_session_maker(async_engine):
"""Create an async session maker bound to the async engine."""
async_session_maker = async_sessionmaker(
bind=async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
return async_session_maker
def add_minimal_fixtures(session_maker):
with session_maker() as session:
session.add(

View File

@@ -7,7 +7,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from integrations.jira.jira_manager import JiraManager
from integrations.jira.jira_payload import JiraEventType, JiraWebhookPayload
from integrations.models import Message, SourceType
from openhands.server.types import (
LLMAuthenticationError,
@@ -274,9 +273,8 @@ class TestSendMessage:
return_value=mock_response
)
message = Message(source=SourceType.JIRA, message='Test message')
result = await jira_manager.send_message(
message,
'Test message',
'PROJ-123',
'cloud-123',
'service@test.com',

View File

@@ -0,0 +1,268 @@
"""
Tests for JiraPayloadParser.
These tests verify the parsing behavior of Jira webhook payloads,
including the handling of optional fields like user_email which
may not be present in webhook payloads from Jira.
"""
import pytest
from integrations.jira.jira_payload import (
JiraEventType,
JiraPayloadError,
JiraPayloadParser,
JiraPayloadSkipped,
JiraPayloadSuccess,
)
@pytest.fixture
def parser():
"""Create a JiraPayloadParser with standard OpenHands labels."""
return JiraPayloadParser(oh_label='openhands', inline_oh_label='@openhands')
@pytest.fixture
def valid_label_payload():
"""Create a valid jira:issue_updated payload with OpenHands label."""
return {
'webhookEvent': 'jira:issue_updated',
'issue': {
'id': '12345',
'key': 'TEST-123',
'self': 'https://test.atlassian.net/rest/api/2/issue/12345',
},
'user': {
'displayName': 'Test User',
'accountId': 'account-123',
'emailAddress': 'test@example.com',
},
'changelog': {
'items': [
{
'field': 'labels',
'toString': 'openhands',
}
]
},
}
@pytest.fixture
def valid_comment_payload():
"""Create a valid comment_created payload with OpenHands mention."""
return {
'webhookEvent': 'comment_created',
'issue': {
'id': '12345',
'key': 'TEST-123',
'self': 'https://test.atlassian.net/rest/api/2/issue/12345',
},
'comment': {
'body': '@openhands please fix this bug',
'author': {
'displayName': 'Test User',
'accountId': 'account-123',
'emailAddress': 'test@example.com',
},
},
}
class TestUserEmailOptional:
"""Tests verifying user_email is optional in webhook payloads.
Jira webhooks may not include emailAddress in the user data.
The parser should accept payloads without this field.
"""
def test_label_event_succeeds_without_email_address(
self, parser, valid_label_payload
):
"""Verify label event parsing succeeds when emailAddress is missing."""
# Arrange - remove emailAddress from user data
del valid_label_payload['user']['emailAddress']
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.user_email == ''
assert result.payload.display_name == 'Test User'
assert result.payload.account_id == 'account-123'
def test_comment_event_succeeds_without_email_address(
self, parser, valid_comment_payload
):
"""Verify comment event parsing succeeds when emailAddress is missing."""
# Arrange - remove emailAddress from author data
del valid_comment_payload['comment']['author']['emailAddress']
# Act
result = parser.parse(valid_comment_payload)
# Assert
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.user_email == ''
assert result.payload.display_name == 'Test User'
assert result.payload.account_id == 'account-123'
def test_user_email_preserved_when_present(self, parser, valid_label_payload):
"""Verify user_email is captured when emailAddress is present."""
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.user_email == 'test@example.com'
class TestRequiredFieldValidation:
"""Tests verifying required fields are still validated."""
def test_missing_issue_id_returns_error(self, parser, valid_label_payload):
"""Verify parsing fails when issue.id is missing."""
# Arrange
del valid_label_payload['issue']['id']
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadError)
assert 'issue.id' in result.error
def test_missing_issue_key_returns_error(self, parser, valid_label_payload):
"""Verify parsing fails when issue.key is missing."""
# Arrange
del valid_label_payload['issue']['key']
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadError)
assert 'issue.key' in result.error
def test_missing_display_name_returns_error(self, parser, valid_label_payload):
"""Verify parsing fails when user.displayName is missing."""
# Arrange
del valid_label_payload['user']['displayName']
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadError)
assert 'displayName' in result.error
def test_missing_account_id_returns_error(self, parser, valid_label_payload):
"""Verify parsing fails when user.accountId is missing."""
# Arrange
del valid_label_payload['user']['accountId']
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadError)
assert 'accountId' in result.error
def test_missing_issue_self_url_returns_error(self, parser, valid_label_payload):
"""Verify parsing fails when issue.self URL is missing."""
# Arrange
del valid_label_payload['issue']['self']
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadError)
assert 'workspace_name' in result.error or 'base_api_url' in result.error
class TestEventTypeDetection:
"""Tests for webhook event type detection."""
def test_issue_updated_with_label_returns_labeled_ticket(
self, parser, valid_label_payload
):
"""Verify jira:issue_updated with label is detected as LABELED_TICKET."""
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.event_type == JiraEventType.LABELED_TICKET
def test_comment_created_with_mention_returns_comment_mention(
self, parser, valid_comment_payload
):
"""Verify comment_created with mention is detected as COMMENT_MENTION."""
# Act
result = parser.parse(valid_comment_payload)
# Assert
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.event_type == JiraEventType.COMMENT_MENTION
def test_unhandled_event_type_returns_skipped(self, parser):
"""Verify unknown event types are skipped."""
# Arrange
payload = {'webhookEvent': 'jira:issue_deleted'}
# Act
result = parser.parse(payload)
# Assert
assert isinstance(result, JiraPayloadSkipped)
assert 'Unhandled' in result.skip_reason
class TestLabelFiltering:
"""Tests for OpenHands label filtering."""
def test_label_event_without_openhands_label_skipped(
self, parser, valid_label_payload
):
"""Verify label events without OpenHands label are skipped."""
# Arrange - change label to something else
valid_label_payload['changelog']['items'][0]['toString'] = 'other-label'
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadSkipped)
assert 'openhands' in result.skip_reason
class TestCommentFiltering:
"""Tests for OpenHands comment mention filtering."""
def test_comment_without_mention_skipped(self, parser, valid_comment_payload):
"""Verify comments without OpenHands mention are skipped."""
# Arrange - remove mention from comment body
valid_comment_payload['comment']['body'] = 'Please fix this bug'
# Act
result = parser.parse(valid_comment_payload)
# Assert
assert isinstance(result, JiraPayloadSkipped)
assert '@openhands' in result.skip_reason
class TestWorkspaceExtraction:
"""Tests for workspace name extraction from issue URL."""
def test_workspace_name_extracted_from_self_url(self, parser, valid_label_payload):
"""Verify workspace name is extracted from issue self URL."""
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.workspace_name == 'test.atlassian.net'
assert result.payload.base_api_url == 'https://test.atlassian.net'

View File

@@ -738,7 +738,7 @@ class TestStartJob:
# Should send error message about re-login
jira_dc_manager.send_message.assert_called_once()
call_args = jira_dc_manager.send_message.call_args[0]
assert 'Please re-login' in call_args[0].message
assert 'Please re-login' in call_args[0]
@pytest.mark.asyncio
async def test_start_job_llm_authentication_error(
@@ -763,7 +763,7 @@ class TestStartJob:
# Should send error message about LLM API key
jira_dc_manager.send_message.assert_called_once()
call_args = jira_dc_manager.send_message.call_args[0]
assert 'valid LLM API key' in call_args[0].message
assert 'valid LLM API key' in call_args[0]
@pytest.mark.asyncio
async def test_start_job_session_expired_error(
@@ -788,8 +788,8 @@ class TestStartJob:
# Should send error message about session expired
jira_dc_manager.send_message.assert_called_once()
call_args = jira_dc_manager.send_message.call_args[0]
assert 'session has expired' in call_args[0].message
assert 'login again' in call_args[0].message
assert 'session has expired' in call_args[0]
assert 'login again' in call_args[0]
@pytest.mark.asyncio
async def test_start_job_unexpected_error(
@@ -814,7 +814,7 @@ class TestStartJob:
# Should send generic error message
jira_dc_manager.send_message.assert_called_once()
call_args = jira_dc_manager.send_message.call_args[0]
assert 'unexpected error' in call_args[0].message
assert 'unexpected error' in call_args[0]
@pytest.mark.asyncio
async def test_start_job_send_message_fails(
@@ -943,9 +943,8 @@ class TestSendMessage:
return_value=mock_response
)
message = Message(source=SourceType.JIRA_DC, message='Test message')
result = await jira_dc_manager.send_message(
message, 'PROJ-123', 'https://jira.company.com', 'bearer_token'
'Test message', 'PROJ-123', 'https://jira.company.com', 'bearer_token'
)
assert result == {'id': 'comment_id'}
@@ -1014,7 +1013,7 @@ class TestSendRepoSelectionComment:
jira_dc_manager.send_message.assert_called_once()
call_args = jira_dc_manager.send_message.call_args[0]
assert 'which repository to work with' in call_args[0].message
assert 'which repository to work with' in call_args[0]
@pytest.mark.asyncio
async def test_send_repo_selection_comment_send_fails(

View File

@@ -18,9 +18,11 @@ from openhands.core.schema.agent import AgentState
class TestJiraDcNewConversationView:
"""Tests for JiraDcNewConversationView"""
def test_get_instructions(self, new_conversation_view, mock_jinja_env):
async def test_get_instructions(self, new_conversation_view, mock_jinja_env):
"""Test _get_instructions method"""
instructions, user_msg = new_conversation_view._get_instructions(mock_jinja_env)
instructions, user_msg = await new_conversation_view._get_instructions(
mock_jinja_env
)
assert instructions == 'Test Jira DC instructions template'
assert 'PROJ-123' in user_msg
@@ -83,9 +85,9 @@ class TestJiraDcNewConversationView:
class TestJiraDcExistingConversationView:
"""Tests for JiraDcExistingConversationView"""
def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
async def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
"""Test _get_instructions method"""
instructions, user_msg = existing_conversation_view._get_instructions(
instructions, user_msg = await existing_conversation_view._get_instructions(
mock_jinja_env
)

View File

@@ -802,7 +802,7 @@ class TestStartJob:
# Should send error message about re-login
linear_manager.send_message.assert_called_once()
call_args = linear_manager.send_message.call_args[0]
assert 'Please re-login' in call_args[0].message
assert 'Please re-login' in call_args[0]
@pytest.mark.asyncio
async def test_start_job_llm_authentication_error(
@@ -828,7 +828,7 @@ class TestStartJob:
# Should send error message about LLM API key
linear_manager.send_message.assert_called_once()
call_args = linear_manager.send_message.call_args[0]
assert 'valid LLM API key' in call_args[0].message
assert 'valid LLM API key' in call_args[0]
@pytest.mark.asyncio
async def test_start_job_session_expired_error(
@@ -854,8 +854,8 @@ class TestStartJob:
# Should send error message about session expired
linear_manager.send_message.assert_called_once()
call_args = linear_manager.send_message.call_args[0]
assert 'session has expired' in call_args[0].message
assert 'login again' in call_args[0].message
assert 'session has expired' in call_args[0]
assert 'login again' in call_args[0]
@pytest.mark.asyncio
async def test_start_job_unexpected_error(
@@ -881,7 +881,7 @@ class TestStartJob:
# Should send generic error message
linear_manager.send_message.assert_called_once()
call_args = linear_manager.send_message.call_args[0]
assert 'unexpected error' in call_args[0].message
assert 'unexpected error' in call_args[0]
@pytest.mark.asyncio
async def test_start_job_send_message_fails(
@@ -1049,8 +1049,9 @@ class TestSendMessage:
linear_manager._query_api = AsyncMock(return_value=mock_response)
message = Message(source=SourceType.LINEAR, message='Test message')
result = await linear_manager.send_message(message, 'issue_id', 'api_key')
result = await linear_manager.send_message(
'Test message', 'issue_id', 'api_key'
)
assert result == mock_response
linear_manager._query_api.assert_called_once()
@@ -1114,7 +1115,7 @@ class TestSendRepoSelectionComment:
linear_manager.send_message.assert_called_once()
call_args = linear_manager.send_message.call_args[0]
assert 'which repository to work with' in call_args[0].message
assert 'which repository to work with' in call_args[0]
@pytest.mark.asyncio
async def test_send_repo_selection_comment_send_fails(

View File

@@ -18,9 +18,11 @@ from openhands.core.schema.agent import AgentState
class TestLinearNewConversationView:
"""Tests for LinearNewConversationView"""
def test_get_instructions(self, new_conversation_view, mock_jinja_env):
async def test_get_instructions(self, new_conversation_view, mock_jinja_env):
"""Test _get_instructions method"""
instructions, user_msg = new_conversation_view._get_instructions(mock_jinja_env)
instructions, user_msg = await new_conversation_view._get_instructions(
mock_jinja_env
)
assert instructions == 'Test instructions template'
assert 'TEST-123' in user_msg
@@ -83,9 +85,9 @@ class TestLinearNewConversationView:
class TestLinearExistingConversationView:
"""Tests for LinearExistingConversationView"""
def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
async def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
"""Test _get_instructions method"""
instructions, user_msg = existing_conversation_view._get_instructions(
instructions, user_msg = await existing_conversation_view._get_instructions(
mock_jinja_env
)

View File

@@ -263,7 +263,9 @@ class TestPausedSandboxResumption:
@patch('openhands.app_server.config.get_httpx_client')
@patch('openhands.app_server.event_callback.util.ensure_running_sandbox')
@patch('openhands.app_server.event_callback.util.get_agent_server_url_from_sandbox')
@patch.object(SlackUpdateExistingConversationView, '_get_instructions')
@patch.object(
SlackUpdateExistingConversationView, '_get_instructions', new_callable=AsyncMock
)
async def test_paused_sandbox_resumption(
self,
mock_get_instructions,

View File

@@ -34,7 +34,6 @@ async def test_send_comment_to_jira_success(mock_jira_manager, processor):
)
mock_jira_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_jira_manager.send_message = AsyncMock()
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
# Action
await processor._send_comment_to_jira('This is a summary.')
@@ -130,7 +129,6 @@ async def test_call_sends_summary_to_jira(
return_value=mock_workspace
)
mock_jira_manager.send_message = AsyncMock()
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
with patch(
'server.conversation_callback_processor.jira_callback_processor.asyncio.create_task'
@@ -200,7 +198,6 @@ async def test_send_comment_to_jira_api_error(mock_jira_manager, processor):
)
mock_jira_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_jira_manager.send_message = AsyncMock(side_effect=Exception('API Error'))
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
# Action - should not raise exception, but handle it gracefully
await processor._send_comment_to_jira('This is a summary.')
@@ -328,18 +325,15 @@ async def test_send_comment_to_jira_message_construction(mock_jira_manager, proc
)
mock_jira_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_jira_manager.send_message = AsyncMock()
mock_outgoing_message = MagicMock()
mock_jira_manager.create_outgoing_message.return_value = mock_outgoing_message
test_message = 'This is a test summary message.'
# Action
await processor._send_comment_to_jira(test_message)
# Assert
mock_jira_manager.create_outgoing_message.assert_called_once_with(msg=test_message)
# Assert - send_message now receives the string directly
mock_jira_manager.send_message.assert_called_once_with(
mock_outgoing_message,
test_message,
issue_key='TEST-123',
jira_cloud_id='cloud123',
svc_acc_email='service@test.com',
@@ -386,7 +380,6 @@ async def test_call_creates_background_task_for_sending(
return_value=mock_workspace
)
mock_jira_manager.send_message = AsyncMock()
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
with patch(
'server.conversation_callback_processor.jira_callback_processor.asyncio.create_task'

View File

@@ -32,7 +32,6 @@ async def test_send_comment_to_jira_dc_success(mock_jira_dc_manager, processor):
)
mock_jira_dc_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_jira_dc_manager.send_message = AsyncMock()
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
# Action
await processor._send_comment_to_jira_dc('This is a summary.')
@@ -125,7 +124,6 @@ async def test_call_sends_summary_to_jira_dc(
return_value=mock_workspace
)
mock_jira_dc_manager.send_message = AsyncMock()
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
with patch(
'server.conversation_callback_processor.jira_dc_callback_processor.asyncio.create_task'
@@ -200,7 +198,6 @@ async def test_send_comment_to_jira_dc_api_error(mock_jira_dc_manager, processor
)
mock_jira_dc_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_jira_dc_manager.send_message = AsyncMock(side_effect=Exception('API Error'))
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
# Action - should not raise exception, but handle it gracefully
await processor._send_comment_to_jira_dc('This is a summary.')
@@ -328,20 +325,15 @@ async def test_send_comment_to_jira_dc_message_construction(
)
mock_jira_dc_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_jira_dc_manager.send_message = AsyncMock()
mock_outgoing_message = MagicMock()
mock_jira_dc_manager.create_outgoing_message.return_value = mock_outgoing_message
test_message = 'This is a test summary message.'
# Action
await processor._send_comment_to_jira_dc(test_message)
# Assert
mock_jira_dc_manager.create_outgoing_message.assert_called_once_with(
msg=test_message
)
# Assert - send_message now receives the string directly
mock_jira_dc_manager.send_message.assert_called_once_with(
mock_outgoing_message,
test_message,
issue_key='TEST-123',
base_api_url='https://test-jira-dc.company.com',
svc_acc_api_key='decrypted_key',
@@ -384,7 +376,6 @@ async def test_call_creates_background_task_for_sending(
return_value=mock_workspace
)
mock_jira_dc_manager.send_message = AsyncMock()
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
with patch(
'server.conversation_callback_processor.jira_dc_callback_processor.asyncio.create_task'

View File

@@ -32,7 +32,6 @@ async def test_send_comment_to_linear_success(mock_linear_manager, processor):
)
mock_linear_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_linear_manager.send_message = AsyncMock()
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
# Action
await processor._send_comment_to_linear('This is a summary.')
@@ -125,7 +124,6 @@ async def test_call_sends_summary_to_linear(
return_value=mock_workspace
)
mock_linear_manager.send_message = AsyncMock()
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
with patch(
'server.conversation_callback_processor.linear_callback_processor.asyncio.create_task'
@@ -200,7 +198,6 @@ async def test_send_comment_to_linear_api_error(mock_linear_manager, processor):
)
mock_linear_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_linear_manager.send_message = AsyncMock(side_effect=Exception('API Error'))
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
# Action - should not raise exception, but handle it gracefully
await processor._send_comment_to_linear('This is a summary.')
@@ -328,20 +325,15 @@ async def test_send_comment_to_linear_message_construction(
)
mock_linear_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_linear_manager.send_message = AsyncMock()
mock_outgoing_message = MagicMock()
mock_linear_manager.create_outgoing_message.return_value = mock_outgoing_message
test_message = 'This is a test summary message.'
# Action
await processor._send_comment_to_linear(test_message)
# Assert
mock_linear_manager.create_outgoing_message.assert_called_once_with(
msg=test_message
)
# Assert - send_message now receives the string directly
mock_linear_manager.send_message.assert_called_once_with(
mock_outgoing_message,
test_message,
'TEST-123', # issue_id
'decrypted_key', # api_key
)
@@ -383,7 +375,6 @@ async def test_call_creates_background_task_for_sending(
return_value=mock_workspace
)
mock_linear_manager.send_message = AsyncMock()
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
with patch(
'server.conversation_callback_processor.linear_callback_processor.asyncio.create_task'

View File

@@ -16,8 +16,15 @@ from storage.device_code import DeviceCode
@pytest.fixture
def mock_device_code_store():
"""Mock device code store."""
return MagicMock()
"""Mock device code store with async methods."""
mock = MagicMock()
mock.create_device_code = AsyncMock()
mock.get_by_device_code = AsyncMock()
mock.get_by_user_code = AsyncMock()
mock.authorize_device_code = AsyncMock()
mock.deny_device_code = AsyncMock()
mock.update_poll_time = AsyncMock()
return mock
@pytest.fixture
@@ -54,7 +61,7 @@ class TestDeviceAuthorization:
expires_at=datetime.now(UTC) + timedelta(minutes=10),
current_interval=5, # Default interval
)
mock_store.create_device_code.return_value = mock_device
mock_store.create_device_code = AsyncMock(return_value=mock_device)
result = await device_authorization(mock_request)
@@ -76,7 +83,7 @@ class TestDeviceAuthorization:
expires_at=datetime.now(UTC) + timedelta(minutes=10),
current_interval=15, # Increased interval from previous rate limiting
)
mock_store.create_device_code.return_value = mock_device
mock_store.create_device_code = AsyncMock(return_value=mock_device)
result = await device_authorization(mock_request)
@@ -113,10 +120,10 @@ class TestDeviceToken:
mock_device.status = status
# Mock rate limiting - return False (not too fast) and default interval
mock_device.check_rate_limit.return_value = (False, 5)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
else:
mock_store.get_by_device_code.return_value = None
mock_store.get_by_device_code = AsyncMock(return_value=None)
result = await device_token(device_code=device_code)
@@ -142,12 +149,14 @@ class TestDeviceToken:
)
# Mock rate limiting - return False (not too fast) and default interval
mock_device.check_rate_limit.return_value = (False, 5)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
# Mock API key retrieval
# Mock API key retrieval - use AsyncMock for async method
mock_api_key_store = MagicMock()
mock_api_key_store.retrieve_api_key_by_name.return_value = 'test-api-key'
mock_api_key_store.retrieve_api_key_by_name = AsyncMock(
return_value='test-api-key'
)
mock_api_key_class.get_instance.return_value = mock_api_key_store
result = await device_token(device_code=device_code)
@@ -176,7 +185,7 @@ class TestDeviceVerificationAuthenticated:
self, mock_store, mock_api_key_class
):
"""Test verification with invalid device code."""
mock_store.get_by_user_code.return_value = None
mock_store.get_by_user_code = AsyncMock(return_value=None)
with pytest.raises(HTTPException):
await device_verification_authenticated(
@@ -189,7 +198,7 @@ class TestDeviceVerificationAuthenticated:
"""Test verification with already processed device code."""
mock_device = MagicMock()
mock_device.is_pending.return_value = False
mock_store.get_by_user_code.return_value = mock_device
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
with pytest.raises(HTTPException):
await device_verification_authenticated(
@@ -203,8 +212,8 @@ class TestDeviceVerificationAuthenticated:
# Mock device code
mock_device = MagicMock()
mock_device.is_pending.return_value = True
mock_store.get_by_user_code.return_value = mock_device
mock_store.authorize_device_code.return_value = True
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
mock_store.authorize_device_code = AsyncMock(return_value=True)
# Mock API key store with async create_api_key
mock_api_key_store = MagicMock()
@@ -248,15 +257,17 @@ class TestDeviceVerificationAuthenticated:
mock_device2.is_pending.return_value = True
# Configure mock store to return appropriate device for each user_code
def get_by_user_code_side_effect(user_code):
async def get_by_user_code_side_effect(user_code):
if user_code == device1_code:
return mock_device1
elif user_code == device2_code:
return mock_device2
return None
mock_store.get_by_user_code.side_effect = get_by_user_code_side_effect
mock_store.authorize_device_code.return_value = True
mock_store.get_by_user_code = AsyncMock(
side_effect=get_by_user_code_side_effect
)
mock_store.authorize_device_code = AsyncMock(return_value=True)
# Authenticate first device
result1 = await device_verification_authenticated(
@@ -305,8 +316,8 @@ class TestDeviceTokenRateLimiting:
last_poll_time=None, # First poll
current_interval=5,
)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
device_code = 'test_device_code'
result = await device_token(device_code=device_code)
@@ -336,8 +347,8 @@ class TestDeviceTokenRateLimiting:
last_poll_time=last_poll,
current_interval=5,
)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
device_code = 'test_device_code'
result = await device_token(device_code=device_code)
@@ -367,8 +378,8 @@ class TestDeviceTokenRateLimiting:
last_poll_time=last_poll,
current_interval=5,
)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
device_code = 'test_device_code'
result = await device_token(device_code=device_code)
@@ -399,8 +410,8 @@ class TestDeviceTokenRateLimiting:
last_poll_time=last_poll,
current_interval=15, # Already increased from previous slow_down
)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
device_code = 'test_device_code'
result = await device_token(device_code=device_code)
@@ -430,8 +441,8 @@ class TestDeviceTokenRateLimiting:
last_poll_time=last_poll,
current_interval=58, # Near maximum of 60
)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
device_code = 'test_device_code'
result = await device_token(device_code=device_code)
@@ -457,8 +468,8 @@ class TestDeviceTokenRateLimiting:
last_poll_time=last_poll,
current_interval=5,
)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
device_code = 'test_device_code'
result = await device_token(device_code=device_code)
@@ -487,8 +498,10 @@ class TestDeviceVerificationTransactionIntegrity:
# Mock device code
mock_device = MagicMock()
mock_device.is_pending.return_value = True
mock_store.get_by_user_code.return_value = mock_device
mock_store.authorize_device_code.return_value = False # Authorization fails
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
mock_store.authorize_device_code = AsyncMock(
return_value=False
) # Authorization fails
# Mock API key store with async create_api_key
mock_api_key_store = MagicMock()
@@ -519,9 +532,11 @@ class TestDeviceVerificationTransactionIntegrity:
# Mock device code
mock_device = MagicMock()
mock_device.is_pending.return_value = True
mock_store.get_by_user_code.return_value = mock_device
mock_store.authorize_device_code.return_value = True # Authorization succeeds
mock_store.deny_device_code.return_value = True # Cleanup succeeds
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
mock_store.authorize_device_code = AsyncMock(
return_value=True
) # Authorization succeeds
mock_store.deny_device_code = AsyncMock(return_value=True) # Cleanup succeeds
# Mock API key store to fail on creation (async)
mock_api_key_store = MagicMock()
@@ -559,10 +574,12 @@ class TestDeviceVerificationTransactionIntegrity:
# Mock device code
mock_device = MagicMock()
mock_device.is_pending.return_value = True
mock_store.get_by_user_code.return_value = mock_device
mock_store.authorize_device_code.return_value = True # Authorization succeeds
mock_store.deny_device_code.side_effect = Exception(
'Cleanup failed'
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
mock_store.authorize_device_code = AsyncMock(
return_value=True
) # Authorization succeeds
mock_store.deny_device_code = AsyncMock(
side_effect=Exception('Cleanup failed')
) # Cleanup fails
# Mock API key store to fail on creation (async)
@@ -595,8 +612,11 @@ class TestDeviceVerificationTransactionIntegrity:
# Mock device code
mock_device = MagicMock()
mock_device.is_pending.return_value = True
mock_store.get_by_user_code.return_value = mock_device
mock_store.authorize_device_code.return_value = True # Authorization succeeds
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
mock_store.authorize_device_code = AsyncMock(
return_value=True
) # Authorization succeeds
mock_store.deny_device_code = AsyncMock()
# Mock API key store with async create_api_key
mock_api_key_store = MagicMock()

View File

@@ -11,41 +11,37 @@ import httpx
import pytest
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.testclient import TestClient
from server.email_validation import get_admin_user_id
from server.routes.org_models import (
CannotModifySelfError,
InsufficientPermissionError,
InvalidRoleError,
LastOwnerError,
LiteLLMIntegrationError,
MeResponse,
OrgAppSettingsResponse,
OrgAppSettingsUpdate,
OrgAuthorizationError,
OrgDatabaseError,
OrgMemberNotFoundError,
OrgMemberPage,
OrgMemberResponse,
OrgMemberUpdate,
OrgNameExistsError,
OrgNotFoundError,
OrphanedUserError,
RoleNotFoundError,
)
from server.routes.orgs import (
get_me,
get_org_members,
org_router,
remove_org_member,
update_org_member,
)
from storage.org import Org
# Mock database before imports
with patch('storage.database.engine', create=True), patch(
'storage.database.a_engine', create=True
):
from server.email_validation import get_admin_user_id
from server.routes.org_models import (
CannotModifySelfError,
InsufficientPermissionError,
InvalidRoleError,
LastOwnerError,
LiteLLMIntegrationError,
MeResponse,
OrgAuthorizationError,
OrgDatabaseError,
OrgMemberNotFoundError,
OrgMemberPage,
OrgMemberResponse,
OrgMemberUpdate,
OrgNameExistsError,
OrgNotFoundError,
OrphanedUserError,
RoleNotFoundError,
)
from server.routes.orgs import (
get_me,
get_org_members,
org_router,
remove_org_member,
update_org_member,
)
from storage.org import Org
from openhands.server.user_auth import get_user_id
from openhands.server.user_auth import get_user_id
# Test user ID constant (must be a valid UUID string)
TEST_USER_ID = str(uuid.uuid4())
@@ -3424,3 +3420,421 @@ async def test_switch_org_database_error(mock_app_with_get_user_id):
# Assert
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert 'Failed to switch organization' in response.json()['detail']
# =============================================================================
# Tests for App Settings Endpoints
# =============================================================================
@pytest.fixture
def mock_member_role():
"""Create a mock member role for authorization tests."""
mock_role = MagicMock()
mock_role.name = 'member'
return mock_role
@pytest.mark.asyncio
async def test_get_org_app_settings_success(
mock_app_with_get_user_id, mock_member_role
):
"""
GIVEN: Authenticated user with MANAGE_APPLICATION_SETTINGS permission
WHEN: GET /api/organizations/app is called
THEN: App settings are returned with 200 status
"""
# Arrange
mock_response = OrgAppSettingsResponse(
enable_proactive_conversation_starters=True,
enable_solvability_analysis=False,
max_budget_per_task=10.0,
)
with (
patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_member_role),
),
patch(
'server.routes.orgs.OrgAppSettingsService.get_org_app_settings',
AsyncMock(return_value=mock_response),
),
):
client = TestClient(mock_app_with_get_user_id)
# Act
response = client.get('/api/organizations/app')
# Assert
assert response.status_code == status.HTTP_200_OK
response_data = response.json()
assert response_data['enable_proactive_conversation_starters'] is True
assert response_data['enable_solvability_analysis'] is False
assert response_data['max_budget_per_task'] == 10.0
@pytest.mark.asyncio
async def test_get_org_app_settings_with_null_values(
mock_app_with_get_user_id, mock_member_role
):
"""
GIVEN: Organization has null app settings values
WHEN: GET /api/organizations/app is called
THEN: Default values are returned where applicable
"""
# Arrange
# OrgAppSettingsResponse.from_org() handles defaults, so we test the response model
mock_response = OrgAppSettingsResponse(
enable_proactive_conversation_starters=True, # Default when None in Org
enable_solvability_analysis=None,
max_budget_per_task=None,
)
with (
patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_member_role),
),
patch(
'server.routes.orgs.OrgAppSettingsService.get_org_app_settings',
AsyncMock(return_value=mock_response),
),
):
client = TestClient(mock_app_with_get_user_id)
# Act
response = client.get('/api/organizations/app')
# Assert
assert response.status_code == status.HTTP_200_OK
response_data = response.json()
# enable_proactive_conversation_starters defaults to True when None
assert response_data['enable_proactive_conversation_starters'] is True
assert response_data['enable_solvability_analysis'] is None
assert response_data['max_budget_per_task'] is None
@pytest.mark.asyncio
async def test_get_org_app_settings_not_found(
mock_app_with_get_user_id, mock_member_role
):
"""
GIVEN: User has no current organization
WHEN: GET /api/organizations/app is called
THEN: 404 Not Found error is returned
"""
# Arrange
with (
patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_member_role),
),
patch(
'server.routes.orgs.OrgAppSettingsService.get_org_app_settings',
AsyncMock(side_effect=OrgNotFoundError('current')),
),
):
client = TestClient(mock_app_with_get_user_id)
# Act
response = client.get('/api/organizations/app')
# Assert
assert response.status_code == status.HTTP_404_NOT_FOUND
assert 'not found' in response.json()['detail'].lower()
@pytest.mark.asyncio
async def test_get_org_app_settings_user_not_member(mock_app_with_get_user_id):
"""
GIVEN: User is not a member of any organization
WHEN: GET /api/organizations/app is called
THEN: 403 Forbidden error is returned
"""
# Arrange - user has no role (not a member)
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=None),
):
client = TestClient(mock_app_with_get_user_id)
# Act
response = client.get('/api/organizations/app')
# Assert
assert response.status_code == status.HTTP_403_FORBIDDEN
assert 'not a member' in response.json()['detail'].lower()
@pytest.mark.asyncio
async def test_update_org_app_settings_success(
mock_app_with_get_user_id, mock_member_role
):
"""
GIVEN: Valid update data and authenticated user
WHEN: POST /api/organizations/app is called
THEN: Updated app settings are returned with 200 status
"""
# Arrange
mock_response = OrgAppSettingsResponse(
enable_proactive_conversation_starters=False,
enable_solvability_analysis=True,
max_budget_per_task=25.0,
)
with (
patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_member_role),
),
patch(
'server.routes.orgs.OrgAppSettingsService.update_org_app_settings',
AsyncMock(return_value=mock_response),
) as mock_update,
):
client = TestClient(mock_app_with_get_user_id)
# Act
response = client.post(
'/api/organizations/app',
json={
'enable_proactive_conversation_starters': False,
'enable_solvability_analysis': True,
'max_budget_per_task': 25.0,
},
)
# Assert
assert response.status_code == status.HTTP_200_OK
response_data = response.json()
assert response_data['enable_proactive_conversation_starters'] is False
assert response_data['enable_solvability_analysis'] is True
assert response_data['max_budget_per_task'] == 25.0
mock_update.assert_called_once()
@pytest.mark.asyncio
async def test_update_org_app_settings_partial_update(
mock_app_with_get_user_id, mock_member_role
):
"""
GIVEN: Partial update data (only some fields)
WHEN: POST /api/organizations/app is called
THEN: Only specified fields are updated
"""
# Arrange
mock_response = OrgAppSettingsResponse(
enable_proactive_conversation_starters=False,
enable_solvability_analysis=True,
max_budget_per_task=10.0, # Unchanged
)
with (
patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_member_role),
),
patch(
'server.routes.orgs.OrgAppSettingsService.update_org_app_settings',
AsyncMock(return_value=mock_response),
) as mock_update,
):
client = TestClient(mock_app_with_get_user_id)
# Act - only updating one field
response = client.post(
'/api/organizations/app',
json={'enable_proactive_conversation_starters': False},
)
# Assert
assert response.status_code == status.HTTP_200_OK
mock_update.assert_called_once()
# Verify the update data only contains the specified field
call_args = mock_update.call_args
update_data = call_args[0][0] # First positional argument (update_data)
assert isinstance(update_data, OrgAppSettingsUpdate)
@pytest.mark.asyncio
async def test_update_org_app_settings_set_null(
mock_app_with_get_user_id, mock_member_role
):
"""
GIVEN: Request to set max_budget_per_task to null
WHEN: POST /api/organizations/app is called
THEN: The field is set to null successfully
"""
# Arrange
mock_response = OrgAppSettingsResponse(
enable_proactive_conversation_starters=True,
enable_solvability_analysis=True,
max_budget_per_task=None,
)
with (
patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_member_role),
),
patch(
'server.routes.orgs.OrgAppSettingsService.update_org_app_settings',
AsyncMock(return_value=mock_response),
),
):
client = TestClient(mock_app_with_get_user_id)
# Act - explicitly setting max_budget_per_task to null
response = client.post(
'/api/organizations/app',
json={'max_budget_per_task': None},
)
# Assert
assert response.status_code == status.HTTP_200_OK
response_data = response.json()
assert response_data['max_budget_per_task'] is None
@pytest.mark.asyncio
async def test_update_org_app_settings_invalid_max_budget(
mock_app_with_get_user_id, mock_member_role
):
"""
GIVEN: Invalid max_budget_per_task value (zero or negative)
WHEN: POST /api/organizations/app is called
THEN: 422 Validation error is returned
"""
# Arrange
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_member_role),
):
client = TestClient(mock_app_with_get_user_id)
# Act - negative value
response = client.post(
'/api/organizations/app',
json={'max_budget_per_task': -5.0},
)
# Assert
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_update_org_app_settings_zero_max_budget(
mock_app_with_get_user_id, mock_member_role
):
"""
GIVEN: max_budget_per_task is set to zero
WHEN: POST /api/organizations/app is called
THEN: 422 Validation error is returned (must be greater than 0)
"""
# Arrange
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_member_role),
):
client = TestClient(mock_app_with_get_user_id)
# Act - zero value
response = client.post(
'/api/organizations/app',
json={'max_budget_per_task': 0},
)
# Assert
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_update_org_app_settings_not_found(
mock_app_with_get_user_id, mock_member_role
):
"""
GIVEN: User has no current organization
WHEN: POST /api/organizations/app is called
THEN: 404 Not Found error is returned
"""
# Arrange
with (
patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_member_role),
),
patch(
'server.routes.orgs.OrgAppSettingsService.update_org_app_settings',
AsyncMock(side_effect=OrgNotFoundError('current')),
),
):
client = TestClient(mock_app_with_get_user_id)
# Act
response = client.post(
'/api/organizations/app',
json={'enable_proactive_conversation_starters': False},
)
# Assert
assert response.status_code == status.HTTP_404_NOT_FOUND
assert 'not found' in response.json()['detail'].lower()
@pytest.mark.asyncio
async def test_update_org_app_settings_database_error(
mock_app_with_get_user_id, mock_member_role
):
"""
GIVEN: Database update fails
WHEN: POST /api/organizations/app is called
THEN: 500 Internal Server Error is returned
"""
# Arrange
with (
patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_member_role),
),
patch(
'server.routes.orgs.OrgAppSettingsService.update_org_app_settings',
AsyncMock(side_effect=Exception('Database connection failed')),
),
):
client = TestClient(mock_app_with_get_user_id)
# Act
response = client.post(
'/api/organizations/app',
json={'enable_proactive_conversation_starters': False},
)
# Assert
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert 'unexpected error' in response.json()['detail'].lower()
@pytest.mark.asyncio
async def test_update_org_app_settings_user_not_member(mock_app_with_get_user_id):
"""
GIVEN: User is not a member of any organization
WHEN: POST /api/organizations/app is called
THEN: 403 Forbidden error is returned
"""
# Arrange - user has no role (not a member)
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=None),
):
client = TestClient(mock_app_with_get_user_id)
# Act
response = client.post(
'/api/organizations/app',
json={'enable_proactive_conversation_starters': False},
)
# Assert
assert response.status_code == status.HTTP_403_FORBIDDEN
assert 'not a member' in response.json()['detail'].lower()

View File

@@ -0,0 +1,173 @@
"""
Unit tests for OrgAppSettingsService.
Tests the service layer for organization app settings operations.
"""
import uuid
from unittest.mock import AsyncMock, MagicMock
import pytest
from server.routes.org_models import (
OrgAppSettingsResponse,
OrgAppSettingsUpdate,
OrgNotFoundError,
)
from server.services.org_app_settings_service import OrgAppSettingsService
from storage.org import Org
@pytest.fixture
def user_id():
"""Create a test user ID."""
return str(uuid.uuid4())
@pytest.fixture
def mock_org():
"""Create a mock organization with app settings."""
org = MagicMock(spec=Org)
org.id = uuid.uuid4()
org.enable_proactive_conversation_starters = True
org.enable_solvability_analysis = False
org.max_budget_per_task = 25.0
return org
@pytest.fixture
def mock_store():
"""Create a mock OrgAppSettingsStore."""
return MagicMock()
@pytest.fixture
def mock_user_context(user_id):
"""Create a mock UserContext that returns the user_id."""
context = MagicMock()
context.get_user_id = AsyncMock(return_value=user_id)
return context
@pytest.mark.asyncio
async def test_get_org_app_settings_success(
user_id, mock_org, mock_store, mock_user_context
):
"""
GIVEN: A user's current organization exists
WHEN: get_org_app_settings is called
THEN: OrgAppSettingsResponse is returned with correct data
"""
# Arrange
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
# Act
result = await service.get_org_app_settings()
# Assert
assert isinstance(result, OrgAppSettingsResponse)
assert result.enable_proactive_conversation_starters is True
assert result.enable_solvability_analysis is False
assert result.max_budget_per_task == 25.0
mock_store.get_current_org_by_user_id.assert_called_once_with(user_id)
@pytest.mark.asyncio
async def test_get_org_app_settings_org_not_found(
user_id, mock_store, mock_user_context
):
"""
GIVEN: A user has no current organization
WHEN: get_org_app_settings is called
THEN: OrgNotFoundError is raised
"""
# Arrange
mock_store.get_current_org_by_user_id = AsyncMock(return_value=None)
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
# Act & Assert
with pytest.raises(OrgNotFoundError) as exc_info:
await service.get_org_app_settings()
assert 'current' in str(exc_info.value)
@pytest.mark.asyncio
async def test_update_org_app_settings_success(
user_id, mock_org, mock_store, mock_user_context
):
"""
GIVEN: A user's current organization exists
WHEN: update_org_app_settings is called with new values
THEN: OrgAppSettingsResponse is returned with updated data
"""
# Arrange
mock_org.enable_proactive_conversation_starters = False
mock_org.max_budget_per_task = 50.0
update_data = OrgAppSettingsUpdate(
enable_proactive_conversation_starters=False,
max_budget_per_task=50.0,
)
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
mock_store.update_org_app_settings = AsyncMock(return_value=mock_org)
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
# Act
result = await service.update_org_app_settings(update_data)
# Assert
assert isinstance(result, OrgAppSettingsResponse)
assert result.enable_proactive_conversation_starters is False
assert result.max_budget_per_task == 50.0
mock_store.update_org_app_settings.assert_called_once_with(
org_id=mock_org.id, update_data=update_data
)
@pytest.mark.asyncio
async def test_update_org_app_settings_no_changes(
user_id, mock_org, mock_store, mock_user_context
):
"""
GIVEN: A user's current organization exists
WHEN: update_org_app_settings is called with no fields
THEN: Current settings are returned without calling update
"""
# Arrange
update_data = OrgAppSettingsUpdate() # No fields set
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
mock_store.update_org_app_settings = AsyncMock()
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
# Act
result = await service.update_org_app_settings(update_data)
# Assert
assert isinstance(result, OrgAppSettingsResponse)
mock_store.get_current_org_by_user_id.assert_called_once_with(user_id)
mock_store.update_org_app_settings.assert_not_called()
@pytest.mark.asyncio
async def test_update_org_app_settings_org_not_found(
user_id, mock_store, mock_user_context
):
"""
GIVEN: A user has no current organization
WHEN: update_org_app_settings is called
THEN: OrgNotFoundError is raised
"""
# Arrange
update_data = OrgAppSettingsUpdate(enable_proactive_conversation_starters=False)
mock_store.get_current_org_by_user_id = AsyncMock(return_value=None)
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
# Act & Assert
with pytest.raises(OrgNotFoundError) as exc_info:
await service.update_org_app_settings(update_data)
assert 'current' in str(exc_info.value)

View File

@@ -0,0 +1,215 @@
"""
Unit tests for OrgLLMSettingsService.
Tests the service layer for organization LLM settings operations.
"""
import uuid
from unittest.mock import AsyncMock, MagicMock
import pytest
from server.routes.org_models import (
OrgLLMSettingsResponse,
OrgLLMSettingsUpdate,
OrgNotFoundError,
)
from server.services.org_llm_settings_service import OrgLLMSettingsService
from storage.org import Org
@pytest.fixture
def user_id():
"""Create a test user ID."""
return str(uuid.uuid4())
@pytest.fixture
def org_id():
"""Create a test org ID."""
return uuid.uuid4()
@pytest.fixture
def mock_org(org_id):
"""Create a mock organization with LLM settings."""
org = MagicMock(spec=Org)
org.id = org_id
org.default_llm_model = 'claude-3'
org.default_llm_base_url = 'https://api.anthropic.com'
org.search_api_key = None
org.agent = 'CodeActAgent'
org.confirmation_mode = True
org.security_analyzer = None
org.enable_default_condenser = True
org.condenser_max_size = None
org.default_max_iterations = 50
return org
@pytest.fixture
def mock_store():
"""Create a mock OrgLLMSettingsStore."""
return MagicMock()
@pytest.fixture
def mock_user_context(user_id):
"""Create a mock UserContext that returns the user_id."""
context = MagicMock()
context.get_user_id = AsyncMock(return_value=user_id)
return context
@pytest.mark.asyncio
async def test_get_org_llm_settings_success(
user_id, mock_org, mock_store, mock_user_context
):
"""
GIVEN: A user with a current organization
WHEN: get_org_llm_settings is called
THEN: OrgLLMSettingsResponse is returned with correct data
"""
# Arrange
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
# Act
result = await service.get_org_llm_settings()
# Assert
assert isinstance(result, OrgLLMSettingsResponse)
assert result.default_llm_model == 'claude-3'
assert result.agent == 'CodeActAgent'
mock_store.get_current_org_by_user_id.assert_called_once_with(user_id)
@pytest.mark.asyncio
async def test_get_org_llm_settings_user_not_authenticated(mock_store):
"""
GIVEN: A user is not authenticated
WHEN: get_org_llm_settings is called
THEN: ValueError is raised
"""
# Arrange
mock_user_context = MagicMock()
mock_user_context.get_user_id = AsyncMock(return_value=None)
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
# Act & Assert
with pytest.raises(ValueError) as exc_info:
await service.get_org_llm_settings()
assert 'not authenticated' in str(exc_info.value)
@pytest.mark.asyncio
async def test_get_org_llm_settings_org_not_found(
user_id, mock_store, mock_user_context
):
"""
GIVEN: A user has no current organization
WHEN: get_org_llm_settings is called
THEN: OrgNotFoundError is raised
"""
# Arrange
mock_store.get_current_org_by_user_id = AsyncMock(return_value=None)
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
# Act & Assert
with pytest.raises(OrgNotFoundError) as exc_info:
await service.get_org_llm_settings()
assert 'No current organization' in str(exc_info.value)
@pytest.mark.asyncio
async def test_update_org_llm_settings_success(
user_id, mock_org, mock_store, mock_user_context
):
"""
GIVEN: A user with a current organization
WHEN: update_org_llm_settings is called with new values
THEN: OrgLLMSettingsResponse is returned with updated data
"""
# Arrange
updated_org = MagicMock(spec=Org)
updated_org.id = mock_org.id
updated_org.default_llm_model = 'new-model'
updated_org.default_llm_base_url = None
updated_org.search_api_key = None
updated_org.agent = 'CodeActAgent'
updated_org.confirmation_mode = False
updated_org.security_analyzer = None
updated_org.enable_default_condenser = True
updated_org.condenser_max_size = None
updated_org.default_max_iterations = 100
update_data = OrgLLMSettingsUpdate(
default_llm_model='new-model',
confirmation_mode=False,
default_max_iterations=100,
)
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
mock_store.update_org_llm_settings = AsyncMock(return_value=updated_org)
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
# Act
result = await service.update_org_llm_settings(update_data)
# Assert
assert isinstance(result, OrgLLMSettingsResponse)
assert result.default_llm_model == 'new-model'
assert result.confirmation_mode is False
assert result.default_max_iterations == 100
mock_store.update_org_llm_settings.assert_called_once_with(
org_id=mock_org.id,
update_data=update_data,
)
@pytest.mark.asyncio
async def test_update_org_llm_settings_no_changes(
user_id, mock_org, mock_store, mock_user_context
):
"""
GIVEN: A user with a current organization
WHEN: update_org_llm_settings is called with no fields
THEN: Current settings are returned without calling update
"""
# Arrange
update_data = OrgLLMSettingsUpdate() # No fields set
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
mock_store.update_org_llm_settings = AsyncMock()
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
# Act
result = await service.update_org_llm_settings(update_data)
# Assert
assert isinstance(result, OrgLLMSettingsResponse)
assert result.default_llm_model == 'claude-3'
mock_store.update_org_llm_settings.assert_not_called()
@pytest.mark.asyncio
async def test_update_org_llm_settings_org_not_found(
user_id, mock_store, mock_user_context
):
"""
GIVEN: A user has no current organization
WHEN: update_org_llm_settings is called
THEN: OrgNotFoundError is raised
"""
# Arrange
update_data = OrgLLMSettingsUpdate(default_llm_model='new-model')
mock_store.get_current_org_by_user_id = AsyncMock(return_value=None)
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
# Act & Assert
with pytest.raises(OrgNotFoundError) as exc_info:
await service.update_org_llm_settings(update_data)
assert 'No current organization' in str(exc_info.value)

View File

@@ -399,3 +399,135 @@ class TestUpdateActiveWorkingSeconds:
assert conversation_work.seconds == 23.0
assert conversation_work.conversation_id == conversation_id
assert conversation_work.user_id == user_id
class TestInvokeConversationCallbacks:
"""Tests for invoke_conversation_callbacks function.
This function uses async database sessions (a_session_maker) to query
and invoke callbacks for a conversation.
"""
@pytest.fixture
def mock_observation(self):
"""Create a mock AgentStateChangedObservation."""
observation = Mock(spec=AgentStateChangedObservation)
observation.agent_state = AgentState.FINISHED
return observation
@pytest.fixture
def create_mock_async_session(self):
"""Factory to create properly mocked async session context manager."""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock
def _create(callbacks_list):
mock_session = Mock()
mock_result = Mock()
mock_result.scalars.return_value.all.return_value = callbacks_list
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock(return_value=None)
@asynccontextmanager
async def mock_context_manager():
yield mock_session
return mock_context_manager, mock_session
return _create
@pytest.mark.asyncio
async def test_invoke_callbacks_with_active_callbacks(
self, mock_observation, create_mock_async_session
):
"""Test that active callbacks are invoked successfully."""
from unittest.mock import AsyncMock
# Arrange
conversation_id = 'test_conversation_callbacks'
mock_processor = AsyncMock(return_value=None)
# Create a mock callback
mock_callback = Mock()
mock_callback.id = 1
mock_callback.processor_type = 'test_processor'
mock_callback.get_processor.return_value = mock_processor
mock_context_manager, mock_session = create_mock_async_session([mock_callback])
# Act
with patch(
'server.utils.conversation_callback_utils.a_session_maker',
mock_context_manager,
):
from server.utils.conversation_callback_utils import (
invoke_conversation_callbacks,
)
await invoke_conversation_callbacks(conversation_id, mock_observation)
# Assert
mock_callback.get_processor.assert_called_once()
mock_processor.assert_called_once_with(mock_callback, mock_observation)
@pytest.mark.asyncio
async def test_invoke_callbacks_with_no_active_callbacks(
self, mock_observation, create_mock_async_session
):
"""Test behavior when no active callbacks exist."""
# Arrange
conversation_id = 'test_no_callbacks'
mock_context_manager, mock_session = create_mock_async_session([])
# Act
with patch(
'server.utils.conversation_callback_utils.a_session_maker',
mock_context_manager,
):
from server.utils.conversation_callback_utils import (
invoke_conversation_callbacks,
)
await invoke_conversation_callbacks(conversation_id, mock_observation)
# Assert - should complete without errors
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_invoke_callbacks_handles_processor_exception(
self, mock_observation, create_mock_async_session
):
"""Test that processor exceptions are caught and callback status is updated."""
from unittest.mock import AsyncMock
# Arrange
conversation_id = 'test_callback_error'
mock_processor = AsyncMock(side_effect=Exception('Processor error'))
mock_callback = Mock()
mock_callback.id = 1
mock_callback.processor_type = 'failing_processor'
mock_callback.get_processor.return_value = mock_processor
mock_callback.status = 'active'
mock_context_manager, mock_session = create_mock_async_session([mock_callback])
# Act
with patch(
'server.utils.conversation_callback_utils.a_session_maker',
mock_context_manager,
), patch('server.utils.conversation_callback_utils.logger') as mock_logger:
from server.utils.conversation_callback_utils import (
invoke_conversation_callbacks,
)
from storage.conversation_callback import CallbackStatus
await invoke_conversation_callbacks(conversation_id, mock_observation)
# Assert - callback status should be set to ERROR
assert mock_callback.status == CallbackStatus.ERROR
mock_logger.error.assert_called_once()
error_call = mock_logger.error.call_args
assert error_call[0][0] == 'callback_invocation_failed'

View File

@@ -1,127 +1,127 @@
"""Unit tests for AuthTokenStore."""
"""Unit tests for AuthTokenStore using SQLite in-memory database."""
import time
from contextlib import asynccontextmanager
from typing import Dict
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import patch
import pytest
from server.auth.auth_error import TokenRefreshError
from sqlalchemy.exc import OperationalError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.auth_token_store import (
ACCESS_TOKEN_EXPIRY_BUFFER,
LOCK_TIMEOUT_SECONDS,
AuthTokenStore,
)
from storage.auth_tokens import AuthTokens
from storage.base import Base
from openhands.integrations.service_types import ProviderType
def create_mock_session():
"""Create a mock async session with properly configured context managers."""
session = AsyncMock()
# Create async context manager for begin()
@asynccontextmanager
async def begin_context():
yield
session.begin = begin_context
return session
def create_mock_session_maker(mock_session):
"""Create a mock async session maker."""
@asynccontextmanager
async def session_context():
yield mock_session
# Return a callable that returns the context manager
return lambda: session_context()
@pytest.fixture
def mock_session():
"""Create mock async session."""
return create_mock_session()
@pytest.fixture
def mock_session_maker(mock_session):
"""Create mock async session maker."""
return create_mock_session_maker(mock_session)
@pytest.fixture
def auth_token_store(mock_session_maker):
"""Create AuthTokenStore instance with mocked session maker."""
return AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
)
return engine
@pytest.fixture
async def async_session_maker(async_engine):
"""Create an async session maker bound to the async engine."""
async_session_maker = async_sessionmaker(
bind=async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
# Create all tables
async with async_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
return async_session_maker
class TestIsTokenExpired:
"""Tests for _is_token_expired method."""
def test_both_tokens_valid(self, auth_token_store):
def test_both_tokens_valid(self):
"""Test when both tokens are valid (not expired)."""
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=ProviderType.GITHUB,
)
current_time = int(time.time())
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
refresh_expires = current_time + 1000
access_expired, refresh_expired = auth_token_store._is_token_expired(
access_expired, refresh_expired = store._is_token_expired(
access_expires, refresh_expires
)
assert access_expired is False
assert refresh_expired is False
def test_access_token_expired(self, auth_token_store):
def test_access_token_expired(self):
"""Test when access token is expired but within buffer."""
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=ProviderType.GITHUB,
)
current_time = int(time.time())
# Access token expires within buffer period
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER - 100
refresh_expires = current_time + 10000
access_expired, refresh_expired = auth_token_store._is_token_expired(
access_expired, refresh_expired = store._is_token_expired(
access_expires, refresh_expires
)
assert access_expired is True
assert refresh_expired is False
def test_refresh_token_expired(self, auth_token_store):
def test_refresh_token_expired(self):
"""Test when refresh token is expired."""
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=ProviderType.GITHUB,
)
current_time = int(time.time())
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
refresh_expires = current_time - 100 # Already expired
access_expired, refresh_expired = auth_token_store._is_token_expired(
access_expired, refresh_expired = store._is_token_expired(
access_expires, refresh_expires
)
assert access_expired is False
assert refresh_expired is True
def test_both_tokens_expired(self, auth_token_store):
def test_both_tokens_expired(self):
"""Test when both tokens are expired."""
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=ProviderType.GITHUB,
)
current_time = int(time.time())
access_expires = current_time - 100
refresh_expires = current_time - 100
access_expired, refresh_expired = auth_token_store._is_token_expired(
access_expired, refresh_expired = store._is_token_expired(
access_expires, refresh_expires
)
assert access_expired is True
assert refresh_expired is True
def test_zero_expiration_treated_as_never_expires(self, auth_token_store):
def test_zero_expiration_treated_as_never_expires(self):
"""Test that 0 expiration time is treated as never expires."""
access_expired, refresh_expired = auth_token_store._is_token_expired(0, 0)
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=ProviderType.GITHUB,
)
access_expired, refresh_expired = store._is_token_expired(0, 0)
assert access_expired is False
assert refresh_expired is False
@@ -131,427 +131,188 @@ class TestLoadTokensFastPath:
"""Tests for load_tokens fast path (no lock needed)."""
@pytest.mark.asyncio
async def test_fast_path_token_not_found(
self, auth_token_store, mock_session_maker, mock_session
):
async def test_fast_path_token_not_found(self, async_session_maker):
"""Test fast path returns None when no token record exists."""
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
result = await auth_token_store.load_tokens()
result = await store.load_tokens()
assert result is None
assert result is None
@pytest.mark.asyncio
async def test_fast_path_valid_token_no_refresh_needed(
self, auth_token_store, mock_session_maker, mock_session
):
async def test_fast_path_valid_token_no_refresh_needed(self, async_session_maker):
"""Test fast path returns tokens when they are still valid."""
current_time = int(time.time())
mock_token = MagicMock()
mock_token.access_token = 'valid-access-token'
mock_token.refresh_token = 'valid-refresh-token'
mock_token.access_token_expires_at = (
current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
)
mock_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = mock_token
mock_session.execute = AsyncMock(return_value=mock_result)
# First, store a valid token in the database
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
result = await auth_token_store.load_tokens()
await store.store_tokens(
access_token='valid-access-token',
refresh_token='valid-refresh-token',
access_token_expires_at=current_time
+ ACCESS_TOKEN_EXPIRY_BUFFER
+ 1000,
refresh_token_expires_at=current_time + 10000,
)
assert result is not None
assert result['access_token'] == 'valid-access-token'
assert result['refresh_token'] == 'valid-refresh-token'
# Now load tokens - should return valid tokens without refresh
result = await store.load_tokens()
assert result is not None
assert result['access_token'] == 'valid-access-token'
assert result['refresh_token'] == 'valid-refresh-token'
@pytest.mark.asyncio
async def test_fast_path_no_refresh_callback_provided(
self, auth_token_store, mock_session_maker, mock_session
):
async def test_fast_path_no_refresh_callback_provided(self, async_session_maker):
"""Test fast path returns existing tokens when no refresh callback is provided."""
current_time = int(time.time())
mock_token = MagicMock()
mock_token.access_token = 'expired-access-token'
mock_token.refresh_token = 'valid-refresh-token'
# Expired access token
mock_token.access_token_expires_at = current_time - 100
mock_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = mock_token
mock_session.execute = AsyncMock(return_value=mock_result)
# Store expired access token
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
result = await auth_token_store.load_tokens(check_expiration_and_refresh=None)
await store.store_tokens(
access_token='expired-access-token',
refresh_token='valid-refresh-token',
access_token_expires_at=current_time - 100, # Expired
refresh_token_expires_at=current_time + 10000,
)
assert result is not None
assert result['access_token'] == 'expired-access-token'
# Load without refresh callback - should still return tokens
result = await store.load_tokens(check_expiration_and_refresh=None)
assert result is not None
assert result['access_token'] == 'expired-access-token'
class TestLoadTokensSlowPath:
"""Tests for load_tokens slow path (lock required for refresh)."""
"""Tests for load_tokens slow path (lock required for refresh).
Note: These tests require PostgreSQL's lock_timeout feature which is not
available in SQLite. The slow path tests are skipped when using SQLite.
"""
@pytest.mark.skip(reason='SQLite does not support PostgreSQL lock_timeout syntax')
@pytest.mark.asyncio
async def test_slow_path_successful_refresh(self):
async def test_slow_path_successful_refresh(self, async_session_maker):
"""Test slow path successfully refreshes expired tokens."""
current_time = int(time.time())
mock_session = create_mock_session()
pass
# First call (fast path) - returns expired token
# Second call (slow path) - returns same token for update
expired_token = MagicMock()
expired_token.id = 1
expired_token.access_token = 'expired-access-token'
expired_token.refresh_token = 'valid-refresh-token'
expired_token.access_token_expires_at = current_time - 100 # Expired
expired_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = expired_token
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh(
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
) -> Dict[str, str | int]:
return {
'access_token': 'new-access-token',
'refresh_token': 'new-refresh-token',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
assert result is not None
assert result['access_token'] == 'new-access-token'
assert result['refresh_token'] == 'new-refresh-token'
@pytest.mark.skip(reason='SQLite does not support PostgreSQL lock_timeout syntax')
@pytest.mark.asyncio
async def test_refresh_callback_returns_none(self, async_session_maker):
"""Test behavior when refresh callback returns None (no refresh performed)."""
pass
@pytest.mark.asyncio
async def test_slow_path_double_check_avoids_refresh(self):
"""Test double-check locking: token was refreshed by another request."""
async def test_slow_path_double_check_avoids_refresh(self, async_session_maker):
"""Test double-check pattern avoids unnecessary refresh."""
current_time = int(time.time())
mock_session = create_mock_session()
# Simulate scenario:
# 1. Fast path sees expired token
# 2. While waiting for lock, another request refreshes
# 3. Slow path sees fresh token, skips refresh
call_count = [0]
def create_token():
call_count[0] += 1
token = MagicMock()
token.id = 1
token.access_token = 'fresh-access-token'
token.refresh_token = 'fresh-refresh-token'
if call_count[0] == 1:
# First call (fast path) - expired
token.access_token_expires_at = current_time - 100
else:
# Second call (slow path) - already refreshed
token.access_token_expires_at = (
current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
)
token.refresh_token_expires_at = current_time + 86400
return token
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.side_effect = (
lambda: create_token()
)
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
refresh_called = [False]
async def mock_refresh(
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
) -> Dict[str, str | int]:
refresh_called[0] = True
return {
'access_token': 'should-not-be-used',
'refresh_token': 'should-not-be-used',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
# The refresh callback should not be called because double-check
# found the token was already refreshed
assert result is not None
assert result['access_token'] == 'fresh-access-token'
@pytest.mark.asyncio
async def test_slow_path_token_not_found_after_lock(self):
"""Test slow path returns None if token record disappears after lock."""
current_time = int(time.time())
mock_session = create_mock_session()
# First call (fast path) - token exists but expired
# Second call (slow path with lock) - token no longer exists
call_count = [0]
def get_token():
call_count[0] += 1
if call_count[0] == 1:
token = MagicMock()
token.access_token_expires_at = current_time - 100 # Expired
token.refresh_token_expires_at = current_time + 10000
return token
return None
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.side_effect = get_token
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh(*args) -> Dict[str, str | int]:
return {
'access_token': 'new-token',
'refresh_token': 'new-refresh',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
assert result is None
class TestLoadTokensLockTimeout:
"""Tests for lock timeout handling."""
@pytest.mark.asyncio
async def test_lock_timeout_raises_token_refresh_error(self):
"""Test that lock timeout raises TokenRefreshError."""
current_time = int(time.time())
mock_session = create_mock_session()
# First call (fast path) - returns expired token
expired_token = MagicMock()
expired_token.access_token_expires_at = current_time - 100
expired_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = expired_token
# First execute for fast path succeeds
# Second execute (for slow path) raises OperationalError
call_count = [0]
async def execute_side_effect(*args, **kwargs):
call_count[0] += 1
if call_count[0] <= 1:
return mock_result
# Simulate lock timeout
raise OperationalError(
'canceling statement due to lock timeout', None, None
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
mock_session.execute = execute_side_effect
# Store a token that will be valid when second check happens
await store.store_tokens(
access_token='original-access-token',
refresh_token='valid-refresh-token',
access_token_expires_at=current_time
+ ACCESS_TOKEN_EXPIRY_BUFFER
+ 1000,
refresh_token_expires_at=current_time + 10000,
)
mock_session_maker = create_mock_session_maker(mock_session)
# Load with refresh callback - should NOT refresh since token is valid
result = await store.load_tokens()
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh(*args) -> Dict[str, str | int]:
return {
'access_token': 'new-token',
'refresh_token': 'new-refresh',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
with pytest.raises(TokenRefreshError) as exc_info:
await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
assert 'lock timeout' in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_lock_timeout_preserves_original_exception(self):
"""Test that TokenRefreshError preserves the original OperationalError."""
current_time = int(time.time())
mock_session = create_mock_session()
expired_token = MagicMock()
expired_token.access_token_expires_at = current_time - 100
expired_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = expired_token
original_error = OperationalError(
'canceling statement due to lock timeout', None, None
)
call_count = [0]
async def execute_side_effect(*args, **kwargs):
call_count[0] += 1
if call_count[0] <= 1:
return mock_result
raise original_error
mock_session.execute = execute_side_effect
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh(*args) -> Dict[str, str | int]:
return {
'access_token': 'new-token',
'refresh_token': 'new-refresh',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
with pytest.raises(TokenRefreshError) as exc_info:
await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
# Verify the original exception is chained
assert exc_info.value.__cause__ is original_error
class TestLoadTokensRefreshCallbackBehavior:
"""Tests for refresh callback return values."""
@pytest.mark.asyncio
async def test_refresh_callback_returns_none(self):
"""Test behavior when refresh callback returns None (no refresh performed)."""
current_time = int(time.time())
mock_session = create_mock_session()
expired_token = MagicMock()
expired_token.id = 1
expired_token.access_token = 'old-access-token'
expired_token.refresh_token = 'old-refresh-token'
expired_token.access_token_expires_at = current_time - 100 # Expired
expired_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = expired_token
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh_returns_none(
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
) -> Dict[str, str | int] | None:
return None
result = await auth_store.load_tokens(
check_expiration_and_refresh=mock_refresh_returns_none
)
# Should return the old tokens when refresh returns None
assert result is not None
assert result['access_token'] == 'old-access-token'
assert result['refresh_token'] == 'old-refresh-token'
assert result is not None
assert result['access_token'] == 'original-access-token'
class TestStoreTokens:
"""Tests for store_tokens method."""
@pytest.mark.asyncio
async def test_store_tokens_creates_new_record(self):
async def test_store_tokens_creates_new_record(self, async_session_maker):
"""Test storing tokens when no existing record."""
mock_session = create_mock_session()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.add = MagicMock()
mock_session.commit = AsyncMock()
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
mock_session_maker = create_mock_session_maker(mock_session)
await store.store_tokens(
access_token='new-access-token',
refresh_token='new-refresh-token',
access_token_expires_at=1234567890,
refresh_token_expires_at=1234657890,
)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
await auth_store.store_tokens(
access_token='new-access-token',
refresh_token='new-refresh-token',
access_token_expires_at=1234567890,
refresh_token_expires_at=1234657890,
)
mock_session.add.assert_called_once()
# Verify the token was stored
async with async_session_maker() as session:
result = await session.execute(
select(AuthTokens).where(
AuthTokens.keycloak_user_id == 'test-user-123',
AuthTokens.identity_provider == ProviderType.GITHUB.value,
)
)
token_record = result.scalars().first()
assert token_record is not None
assert token_record.access_token == 'new-access-token'
assert token_record.refresh_token == 'new-refresh-token'
@pytest.mark.asyncio
async def test_store_tokens_updates_existing_record(self):
async def test_store_tokens_updates_existing_record(self, async_session_maker):
"""Test storing tokens updates existing record."""
mock_session = create_mock_session()
existing_token = MagicMock()
existing_token.access_token = 'old-access'
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = existing_token
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
# First, create a token record
await store.store_tokens(
access_token='old-access-token',
refresh_token='old-refresh-token',
access_token_expires_at=1234567890,
refresh_token_expires_at=1234657890,
)
mock_session_maker = create_mock_session_maker(mock_session)
# Now update it
await store.store_tokens(
access_token='new-access-token',
refresh_token='new-refresh-token',
access_token_expires_at=1234567891,
refresh_token_expires_at=1234657891,
)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
await auth_store.store_tokens(
access_token='new-access-token',
refresh_token='new-refresh-token',
access_token_expires_at=1234567890,
refresh_token_expires_at=1234657890,
)
assert existing_token.access_token == 'new-access-token'
assert existing_token.refresh_token == 'new-refresh-token'
# Verify the token was updated
async with async_session_maker() as session:
result = await session.execute(
select(AuthTokens).where(
AuthTokens.keycloak_user_id == 'test-user-123',
AuthTokens.identity_provider == ProviderType.GITHUB.value,
)
)
token_record = result.scalars().first()
assert token_record is not None
assert token_record.access_token == 'new-access-token'
assert token_record.refresh_token == 'new-refresh-token'
class TestIsAccessTokenValid:
@@ -559,80 +320,93 @@ class TestIsAccessTokenValid:
@pytest.mark.asyncio
async def test_is_access_token_valid_returns_false_when_no_tokens(
self, auth_token_store, mock_session_maker, mock_session
self, async_session_maker
):
"""Test returns False when no tokens found."""
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
result = await auth_token_store.is_access_token_valid()
result = await store.is_access_token_valid()
assert result is False
assert result is False
@pytest.mark.asyncio
async def test_is_access_token_valid_returns_true_for_valid_token(
self, auth_token_store, mock_session_maker, mock_session
self, async_session_maker
):
"""Test returns True when token is valid."""
current_time = int(time.time())
mock_token = MagicMock()
mock_token.access_token = 'valid-access'
mock_token.refresh_token = 'valid-refresh'
mock_token.access_token_expires_at = current_time + 1000
mock_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = mock_token
mock_session.execute = AsyncMock(return_value=mock_result)
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
result = await auth_token_store.is_access_token_valid()
await store.store_tokens(
access_token='valid-access',
refresh_token='valid-refresh',
access_token_expires_at=current_time + 1000,
refresh_token_expires_at=current_time + 10000,
)
assert result is True
result = await store.is_access_token_valid()
assert result is True
@pytest.mark.asyncio
async def test_is_access_token_valid_returns_false_for_expired_token(
self, auth_token_store, mock_session_maker, mock_session
self, async_session_maker
):
"""Test returns False when token is expired."""
current_time = int(time.time())
mock_token = MagicMock()
mock_token.access_token = 'expired-access'
mock_token.refresh_token = 'valid-refresh'
mock_token.access_token_expires_at = current_time - 100 # Expired
mock_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = mock_token
mock_session.execute = AsyncMock(return_value=mock_result)
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
result = await auth_token_store.is_access_token_valid()
await store.store_tokens(
access_token='expired-access',
refresh_token='valid-refresh',
access_token_expires_at=current_time - 100, # Expired
refresh_token_expires_at=current_time + 10000,
)
assert result is False
result = await store.is_access_token_valid()
assert result is False
class TestGetInstance:
"""Tests for get_instance class method."""
@pytest.mark.asyncio
async def test_get_instance_creates_auth_token_store(self):
async def test_get_instance_creates_auth_token_store(self, async_session_maker):
"""Test get_instance creates an AuthTokenStore with correct params."""
with patch('storage.auth_token_store.a_session_maker') as mock_a_session_maker:
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = await AuthTokenStore.get_instance(
keycloak_user_id='user-123', idp=ProviderType.GITHUB
)
assert store.keycloak_user_id == 'user-123'
assert store.idp == ProviderType.GITHUB
assert store.a_session_maker is mock_a_session_maker
class TestIdentityProviderValue:
"""Tests for identity_provider_value property."""
def test_identity_provider_value_returns_idp_value(self, auth_token_store):
def test_identity_provider_value_returns_idp_value(self):
"""Test that identity_provider_value returns the enum value."""
assert auth_token_store.identity_provider_value == ProviderType.GITHUB.value
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=ProviderType.GITHUB,
)
assert store.identity_provider_value == ProviderType.GITHUB.value
def test_identity_provider_value_for_different_providers(self):
"""Test identity_provider_value for different providers."""
@@ -644,7 +418,6 @@ class TestIdentityProviderValue:
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=provider,
a_session_maker=MagicMock(),
)
assert store.identity_provider_value == provider.value

View File

@@ -1,33 +1,17 @@
"""Unit tests for DeviceCodeStore."""
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from sqlalchemy.exc import IntegrityError
from sqlalchemy import select
from storage.device_code import DeviceCode
from storage.device_code_store import DeviceCodeStore
@pytest.fixture
def mock_session():
"""Mock database session."""
session = MagicMock()
return session
@pytest.fixture
def mock_session_maker(mock_session):
"""Mock session maker."""
session_maker = MagicMock()
session_maker.return_value.__enter__.return_value = mock_session
session_maker.return_value.__exit__.return_value = None
return session_maker
@pytest.fixture
def device_code_store(mock_session_maker):
def device_code_store():
"""Create DeviceCodeStore instance."""
return DeviceCodeStore(mock_session_maker)
return DeviceCodeStore()
class TestDeviceCodeStore:
@@ -49,145 +33,257 @@ class TestDeviceCodeStore:
assert len(code) == 128
assert code.isalnum()
def test_create_device_code_success(self, device_code_store, mock_session):
@pytest.mark.asyncio
async def test_create_device_code_success(
self, device_code_store, async_session_maker
):
"""Test successful device code creation."""
# Mock successful creation (no IntegrityError)
mock_device_code = MagicMock(spec=DeviceCode)
mock_device_code.device_code = 'test-device-code-123'
mock_device_code.user_code = 'TESTCODE'
# Mock the session to return our mock device code after refresh
def mock_refresh(obj):
obj.device_code = mock_device_code.device_code
obj.user_code = mock_device_code.user_code
mock_session.refresh.side_effect = mock_refresh
result = device_code_store.create_device_code(expires_in=600)
with patch('storage.device_code_store.a_session_maker', async_session_maker):
result = await device_code_store.create_device_code(expires_in=600)
assert isinstance(result, DeviceCode)
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once()
mock_session.expunge.assert_called_once()
assert len(result.device_code) == 128
assert len(result.user_code) == 8
def test_create_device_code_with_retries(
self, device_code_store, mock_session_maker
# Verify the DeviceCode was created in the database
async with async_session_maker() as session:
result_db = await session.execute(
select(DeviceCode).filter(DeviceCode.device_code == result.device_code)
)
device_code = result_db.scalars().first()
assert device_code is not None
assert device_code.user_code == result.user_code
@pytest.mark.asyncio
async def test_create_device_code_with_retries(
self, device_code_store, async_session_maker
):
"""Test device code creation with constraint violation retries."""
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
# First create a device code to cause a collision
with patch('storage.device_code_store.a_session_maker', async_session_maker):
first_code = await device_code_store.create_device_code(expires_in=600)
# First attempt fails with IntegrityError, second succeeds
mock_session.commit.side_effect = [IntegrityError('', '', ''), None]
# Patch generate methods to return the same codes on first attempt,
# then different codes on second attempt
call_count = {'user': 0, 'device': 0}
original_generate_user_code = device_code_store.generate_user_code
original_generate_device_code = device_code_store.generate_device_code
mock_device_code = MagicMock(spec=DeviceCode)
mock_device_code.device_code = 'test-device-code-456'
mock_device_code.user_code = 'TESTCD2'
def mock_generate_user_code():
call_count['user'] += 1
if call_count['user'] == 1:
return first_code.user_code # Collision
return original_generate_user_code()
def mock_refresh(obj):
obj.device_code = mock_device_code.device_code
obj.user_code = mock_device_code.user_code
def mock_generate_device_code():
call_count['device'] += 1
if call_count['device'] == 1:
return first_code.device_code # Collision
return original_generate_device_code()
mock_session.refresh.side_effect = mock_refresh
device_code_store.generate_user_code = mock_generate_user_code
device_code_store.generate_device_code = mock_generate_device_code
store = DeviceCodeStore(mock_session_maker)
result = store.create_device_code(expires_in=600)
with patch('storage.device_code_store.a_session_maker', async_session_maker):
result = await device_code_store.create_device_code(expires_in=600)
assert isinstance(result, DeviceCode)
assert mock_session.add.call_count == 2 # Two attempts
assert mock_session.commit.call_count == 2 # Two attempts
assert result.device_code != first_code.device_code # Should be different
assert call_count['user'] == 2 # Two attempts
def test_create_device_code_max_attempts_exceeded(
self, device_code_store, mock_session_maker
@pytest.mark.asyncio
async def test_create_device_code_max_attempts_exceeded(
self, device_code_store, async_session_maker
):
"""Test device code creation failure after max attempts."""
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
# First create a device code
with patch('storage.device_code_store.a_session_maker', async_session_maker):
first_code = await device_code_store.create_device_code(expires_in=600)
# All attempts fail with IntegrityError
mock_session.commit.side_effect = IntegrityError('', '', '')
# Always return the same codes to cause repeated collisions
device_code_store.generate_user_code = lambda: first_code.user_code
device_code_store.generate_device_code = lambda: first_code.device_code
store = DeviceCodeStore(mock_session_maker)
with patch('storage.device_code_store.a_session_maker', async_session_maker):
with pytest.raises(
RuntimeError,
match='Failed to generate unique device codes after 3 attempts',
):
await device_code_store.create_device_code(
expires_in=600, max_attempts=3
)
with pytest.raises(
RuntimeError,
match='Failed to generate unique device codes after 3 attempts',
):
store.create_device_code(expires_in=600, max_attempts=3)
@pytest.mark.asyncio
async def test_get_by_device_code(self, device_code_store, async_session_maker):
"""Test getting device code by device code."""
# Create a device code first
with patch('storage.device_code_store.a_session_maker', async_session_maker):
created = await device_code_store.create_device_code(expires_in=600)
result = await device_code_store.get_by_device_code(created.device_code)
@pytest.mark.parametrize(
'lookup_method,lookup_field',
[
('get_by_device_code', 'device_code'),
('get_by_user_code', 'user_code'),
],
)
def test_lookup_methods(
self, device_code_store, mock_session, lookup_method, lookup_field
assert result is not None
assert result.device_code == created.device_code
assert result.user_code == created.user_code
@pytest.mark.asyncio
async def test_get_by_device_code_not_found(
self, device_code_store, async_session_maker
):
"""Test device code lookup methods."""
test_code = 'test-code-123'
mock_device_code = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = (
mock_device_code
)
"""Test getting non-existent device code."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
result = await device_code_store.get_by_device_code('non-existent-code')
result = getattr(device_code_store, lookup_method)(test_code)
assert result is None
assert result == mock_device_code
mock_session.query.assert_called_once_with(DeviceCode)
mock_session.query.return_value.filter_by.assert_called_once_with(
**{lookup_field: test_code}
)
@pytest.mark.asyncio
async def test_get_by_user_code(self, device_code_store, async_session_maker):
"""Test getting device code by user code."""
# Create a device code first
with patch('storage.device_code_store.a_session_maker', async_session_maker):
created = await device_code_store.create_device_code(expires_in=600)
result = await device_code_store.get_by_user_code(created.user_code)
@pytest.mark.parametrize(
'device_exists,is_pending,expected_result',
[
(True, True, True), # Success case
(False, True, False), # Device not found
(True, False, False), # Device not pending
],
)
def test_authorize_device_code(
self,
device_code_store,
mock_session,
device_exists,
is_pending,
expected_result,
assert result is not None
assert result.device_code == created.device_code
assert result.user_code == created.user_code
@pytest.mark.asyncio
async def test_get_by_user_code_not_found(
self, device_code_store, async_session_maker
):
"""Test device code authorization."""
user_code = 'ABC12345'
"""Test getting non-existent user code."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
result = await device_code_store.get_by_user_code('NOTFOUND')
assert result is None
@pytest.mark.asyncio
async def test_authorize_device_code_success(
self, device_code_store, async_session_maker
):
"""Test successful device code authorization."""
user_id = 'test-user-123'
if device_exists:
mock_device = MagicMock()
mock_device.is_pending.return_value = is_pending
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_device
else:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
result = device_code_store.authorize_device_code(user_code, user_id)
assert result == expected_result
if expected_result:
mock_device.authorize.assert_called_once_with(user_id)
mock_session.commit.assert_called_once()
def test_deny_device_code(self, device_code_store, mock_session):
"""Test device code denial."""
user_code = 'ABC12345'
mock_device = MagicMock()
mock_device.is_pending.return_value = True
mock_session.query.return_value.filter_by.return_value.first.return_value = (
mock_device
)
result = device_code_store.deny_device_code(user_code)
with patch('storage.device_code_store.a_session_maker', async_session_maker):
created = await device_code_store.create_device_code(expires_in=600)
result = await device_code_store.authorize_device_code(
created.user_code, user_id
)
assert result is True
mock_device.deny.assert_called_once()
mock_session.commit.assert_called_once()
# Verify the device code was authorized in the database
async with async_session_maker() as session:
result_db = await session.execute(
select(DeviceCode).filter(DeviceCode.user_code == created.user_code)
)
device_code = result_db.scalars().first()
assert device_code.status == 'authorized'
assert device_code.keycloak_user_id == user_id
@pytest.mark.asyncio
async def test_authorize_device_code_not_found(
self, device_code_store, async_session_maker
):
"""Test authorizing non-existent device code."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
result = await device_code_store.authorize_device_code(
'NOTFOUND', 'user-123'
)
assert result is False
@pytest.mark.asyncio
async def test_authorize_device_code_not_pending(
self, device_code_store, async_session_maker
):
"""Test authorizing already authorized device code."""
user_id = 'test-user-123'
with patch('storage.device_code_store.a_session_maker', async_session_maker):
created = await device_code_store.create_device_code(expires_in=600)
# First authorization
await device_code_store.authorize_device_code(created.user_code, user_id)
# Second authorization should fail
result = await device_code_store.authorize_device_code(
created.user_code, 'another-user'
)
assert result is False
@pytest.mark.asyncio
async def test_deny_device_code_success(
self, device_code_store, async_session_maker
):
"""Test successful device code denial."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
created = await device_code_store.create_device_code(expires_in=600)
result = await device_code_store.deny_device_code(created.user_code)
assert result is True
# Verify the device code was denied in the database
async with async_session_maker() as session:
result_db = await session.execute(
select(DeviceCode).filter(DeviceCode.user_code == created.user_code)
)
device_code = result_db.scalars().first()
assert device_code.status == 'denied'
@pytest.mark.asyncio
async def test_deny_device_code_not_found(
self, device_code_store, async_session_maker
):
"""Test denying non-existent device code."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
result = await device_code_store.deny_device_code('NOTFOUND')
assert result is False
@pytest.mark.asyncio
async def test_deny_device_code_not_pending(
self, device_code_store, async_session_maker
):
"""Test denying already denied device code."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
created = await device_code_store.create_device_code(expires_in=600)
# First denial
await device_code_store.deny_device_code(created.user_code)
# Second denial should fail
result = await device_code_store.deny_device_code(created.user_code)
assert result is False
@pytest.mark.asyncio
async def test_update_poll_time_success(
self, device_code_store, async_session_maker
):
"""Test updating poll time."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
created = await device_code_store.create_device_code(expires_in=600)
original_interval = created.current_interval
result = await device_code_store.update_poll_time(
created.device_code, increase_interval=True
)
assert result is True
# Verify the poll time was updated
async with async_session_maker() as session:
result_db = await session.execute(
select(DeviceCode).filter(DeviceCode.device_code == created.device_code)
)
device_code = result_db.scalars().first()
assert device_code.current_interval > original_interval
@pytest.mark.asyncio
async def test_update_poll_time_not_found(
self, device_code_store, async_session_maker
):
"""Test updating poll time for non-existent device code."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
result = await device_code_store.update_poll_time(
'non-existent-code', increase_interval=False
)
assert result is False

View File

@@ -9,16 +9,35 @@ from storage.base import Base
from storage.gitlab_webhook import GitlabWebhook
from storage.gitlab_webhook_store import GitlabWebhookStore
# Use module-scoped engine to share database across fixtures
_test_engine = None
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
@pytest.fixture(scope='function')
def event_loop():
"""Create an instance of the default event loop for each test case."""
import asyncio
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope='function')
async def async_engine(event_loop):
"""Create an async SQLite engine for testing.
This fixture creates an in-memory SQLite database and ensures
all tables are created before tests run.
"""
global _test_engine
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
_test_engine = engine
# Create all tables
async with engine.begin() as conn:
@@ -29,7 +48,7 @@ async def async_engine():
await engine.dispose()
@pytest.fixture
@pytest.fixture(scope='function')
async def async_session_maker(async_engine):
"""Create an async session maker for testing."""
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
@@ -37,8 +56,21 @@ async def async_session_maker(async_engine):
@pytest.fixture
async def webhook_store(async_session_maker):
"""Create a GitlabWebhookStore instance for testing."""
return GitlabWebhookStore(a_session_maker=async_session_maker)
"""Create a GitlabWebhookStore instance for testing.
This fixture injects the test's async_session_maker to ensure
the store uses the same in-memory database as the test fixtures.
"""
# Import here to avoid circular imports
store = GitlabWebhookStore()
# Inject the test session maker - this needs to replace the module-level import
import storage.gitlab_webhook_store as store_module
store_module.a_session_maker = async_session_maker
return store
@pytest.fixture
@@ -102,7 +134,7 @@ class TestGetWebhookByResourceOnly:
@pytest.mark.asyncio
async def test_get_project_webhook_by_resource_only(
self, webhook_store, async_session_maker, sample_webhooks
self, webhook_store, sample_webhooks
):
"""Test getting a project webhook by resource ID without user_id filter."""
# Arrange

View File

@@ -0,0 +1,232 @@
"""
Tests for JiraIntegrationStore async methods.
The store uses async database sessions (a_session_maker) for all operations,
which is critical for avoiding asyncpg event loop issues when called from
FastAPI async endpoints.
"""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, Mock, patch
import pytest
from storage.jira_integration_store import JiraIntegrationStore
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
@pytest.fixture
def store():
"""Create a JiraIntegrationStore instance."""
return JiraIntegrationStore()
@pytest.fixture
def create_mock_async_session():
"""Factory to create properly mocked async session context manager."""
def _create(query_result=None, all_results=None):
mock_session = Mock()
mock_result = Mock()
if all_results is not None:
mock_result.scalars.return_value.all.return_value = all_results
else:
mock_result.scalars.return_value.first.return_value = query_result
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.add = Mock()
mock_session.commit = AsyncMock()
mock_session.refresh = AsyncMock()
@asynccontextmanager
async def mock_context_manager():
yield mock_session
return mock_context_manager, mock_session
return _create
class TestJiraIntegrationStoreAsyncMethods:
"""Tests verifying JiraIntegrationStore methods use async sessions correctly."""
@pytest.mark.asyncio
async def test_get_workspace_by_id_returns_workspace(
self, store, create_mock_async_session
):
"""Test get_workspace_by_id returns workspace when found."""
# Arrange
mock_workspace = Mock(spec=JiraWorkspace)
mock_workspace.id = 1
mock_workspace.name = 'test-workspace'
mock_context_manager, mock_session = create_mock_async_session(mock_workspace)
# Act
with patch(
'storage.jira_integration_store.a_session_maker', mock_context_manager
):
result = await store.get_workspace_by_id(1)
# Assert
assert result == mock_workspace
mock_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_workspace_by_id_returns_none_when_not_found(
self, store, create_mock_async_session
):
"""Test get_workspace_by_id returns None when workspace not found."""
# Arrange
mock_context_manager, mock_session = create_mock_async_session(None)
# Act
with patch(
'storage.jira_integration_store.a_session_maker', mock_context_manager
):
result = await store.get_workspace_by_id(999)
# Assert
assert result is None
@pytest.mark.asyncio
async def test_get_workspace_by_name_normalizes_to_lowercase(
self, store, create_mock_async_session
):
"""Test get_workspace_by_name converts name to lowercase for query."""
# Arrange
mock_workspace = Mock(spec=JiraWorkspace)
mock_workspace.name = 'test-workspace'
mock_context_manager, mock_session = create_mock_async_session(mock_workspace)
# Act
with patch(
'storage.jira_integration_store.a_session_maker', mock_context_manager
):
result = await store.get_workspace_by_name('TEST-WORKSPACE')
# Assert
assert result == mock_workspace
# Verify the query was executed (filter includes lowercase conversion)
mock_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_active_user_filters_by_status(
self, store, create_mock_async_session
):
"""Test get_active_user only returns users with active status."""
# Arrange
mock_user = Mock(spec=JiraUser)
mock_user.jira_user_id = 'jira-123'
mock_user.jira_workspace_id = 1
mock_user.status = 'active'
mock_context_manager, mock_session = create_mock_async_session(mock_user)
# Act
with patch(
'storage.jira_integration_store.a_session_maker', mock_context_manager
):
result = await store.get_active_user('jira-123', 1)
# Assert
assert result == mock_user
mock_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_create_workspace_adds_and_commits(
self, store, create_mock_async_session
):
"""Test create_workspace properly adds, commits, and refreshes."""
# Arrange
mock_context_manager, mock_session = create_mock_async_session(None)
# Act
with patch(
'storage.jira_integration_store.a_session_maker', mock_context_manager
):
await store.create_workspace(
name='TEST-WORKSPACE',
jira_cloud_id='cloud-123',
admin_user_id='admin-user',
encrypted_webhook_secret='encrypted-secret',
svc_acc_email='svc@test.com',
encrypted_svc_acc_api_key='encrypted-key',
status='active',
)
# Assert
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once()
# Verify workspace was created with lowercase name
added_workspace = mock_session.add.call_args[0][0]
assert added_workspace.name == 'test-workspace'
@pytest.mark.asyncio
async def test_update_user_integration_status_raises_if_not_found(
self, store, create_mock_async_session
):
"""Test update_user_integration_status raises ValueError if user not found."""
# Arrange
mock_context_manager, mock_session = create_mock_async_session(None)
# Act & Assert
with patch(
'storage.jira_integration_store.a_session_maker', mock_context_manager
):
with pytest.raises(ValueError) as exc_info:
await store.update_user_integration_status('unknown-user', 'inactive')
assert 'Jira user not found' in str(exc_info.value)
@pytest.mark.asyncio
async def test_deactivate_workspace_deactivates_all_users(
self, store, create_mock_async_session
):
"""Test deactivate_workspace sets all users and workspace to inactive."""
# Arrange
mock_user1 = Mock(spec=JiraUser)
mock_user1.status = 'active'
mock_user2 = Mock(spec=JiraUser)
mock_user2.status = 'active'
mock_workspace = Mock(spec=JiraWorkspace)
mock_workspace.status = 'active'
mock_session = Mock()
# First execute returns users, second returns workspace
call_count = [0]
def execute_side_effect(*args, **kwargs):
result = Mock()
if call_count[0] == 0:
result.scalars.return_value.all.return_value = [mock_user1, mock_user2]
else:
result.scalars.return_value.first.return_value = mock_workspace
call_count[0] += 1
return result
mock_session.execute = AsyncMock(side_effect=execute_side_effect)
mock_session.add = Mock()
mock_session.commit = AsyncMock()
@asynccontextmanager
async def mock_context_manager():
yield mock_session
# Act
with patch(
'storage.jira_integration_store.a_session_maker', mock_context_manager
):
await store.deactivate_workspace(1)
# Assert
assert mock_user1.status == 'inactive'
assert mock_user2.status == 'inactive'
assert mock_workspace.status == 'inactive'
mock_session.commit.assert_called_once()

View File

@@ -0,0 +1,183 @@
"""
Unit tests for OrgAppSettingsStore.
Tests the async database operations for organization app settings.
"""
import uuid
import pytest
from server.routes.org_models import OrgAppSettingsUpdate
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.base import Base
from storage.org import Org
from storage.org_app_settings_store import OrgAppSettingsStore
from storage.user import User
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session_maker(async_engine):
"""Create an async session maker for testing."""
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
@pytest.mark.asyncio
async def test_get_current_org_by_user_id_success(async_session_maker):
"""
GIVEN: A user exists with a current organization
WHEN: get_current_org_by_user_id is called with the user's ID
THEN: The organization is returned with correct data
"""
# Arrange
async with async_session_maker() as session:
org = Org(
name='test-org',
enable_proactive_conversation_starters=True,
enable_solvability_analysis=False,
max_budget_per_task=25.0,
)
session.add(org)
await session.flush()
user = User(
id=uuid.uuid4(),
current_org_id=org.id,
)
session.add(user)
await session.commit()
user_id = str(user.id)
# Act
store = OrgAppSettingsStore(db_session=session)
result = await store.get_current_org_by_user_id(user_id)
# Assert
assert result is not None
assert result.name == 'test-org'
assert result.enable_proactive_conversation_starters is True
assert result.enable_solvability_analysis is False
assert result.max_budget_per_task == 25.0
@pytest.mark.asyncio
async def test_get_current_org_by_user_id_user_not_found(async_session_maker):
"""
GIVEN: A user does not exist in the database
WHEN: get_current_org_by_user_id is called with a non-existent ID
THEN: None is returned
"""
# Arrange
non_existent_id = str(uuid.uuid4())
# Act
async with async_session_maker() as session:
store = OrgAppSettingsStore(db_session=session)
result = await store.get_current_org_by_user_id(non_existent_id)
# Assert
assert result is None
@pytest.mark.asyncio
async def test_update_org_app_settings_success(async_session_maker):
"""
GIVEN: An organization exists in the database
WHEN: update_org_app_settings is called with new values
THEN: The organization's settings are updated and returned
"""
# Arrange
async with async_session_maker() as session:
org = Org(
name='test-org',
enable_proactive_conversation_starters=True,
enable_solvability_analysis=False,
max_budget_per_task=10.0,
)
session.add(org)
await session.commit()
org_id = org.id
update_data = OrgAppSettingsUpdate(
enable_proactive_conversation_starters=False,
enable_solvability_analysis=True,
max_budget_per_task=50.0,
)
# Act
store = OrgAppSettingsStore(db_session=session)
result = await store.update_org_app_settings(org_id, update_data)
# Assert
assert result is not None
assert result.enable_proactive_conversation_starters is False
assert result.enable_solvability_analysis is True
assert result.max_budget_per_task == 50.0
@pytest.mark.asyncio
async def test_update_org_app_settings_partial(async_session_maker):
"""
GIVEN: An organization exists with existing settings
WHEN: update_org_app_settings is called with only some fields
THEN: Only the provided fields are updated, others remain unchanged
"""
# Arrange
async with async_session_maker() as session:
org = Org(
name='test-org',
enable_proactive_conversation_starters=True,
enable_solvability_analysis=False,
max_budget_per_task=10.0,
)
session.add(org)
await session.commit()
org_id = org.id
# Only update max_budget_per_task
update_data = OrgAppSettingsUpdate(max_budget_per_task=100.0)
# Act
store = OrgAppSettingsStore(db_session=session)
result = await store.update_org_app_settings(org_id, update_data)
# Assert
assert result is not None
assert result.max_budget_per_task == 100.0
assert result.enable_proactive_conversation_starters is True # Unchanged
assert result.enable_solvability_analysis is False # Unchanged
@pytest.mark.asyncio
async def test_update_org_app_settings_org_not_found(async_session_maker):
"""
GIVEN: An organization does not exist in the database
WHEN: update_org_app_settings is called
THEN: None is returned
"""
# Arrange
non_existent_id = uuid.uuid4()
update_data = OrgAppSettingsUpdate(enable_proactive_conversation_starters=False)
# Act
async with async_session_maker() as session:
store = OrgAppSettingsStore(db_session=session)
result = await store.update_org_app_settings(non_existent_id, update_data)
# Assert
assert result is None

View File

@@ -0,0 +1,175 @@
"""
Unit tests for OrgLLMSettingsStore.
Tests the async database operations for organization LLM settings.
"""
import uuid
from unittest.mock import AsyncMock, patch
import pytest
from server.routes.org_models import OrgLLMSettingsUpdate
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.base import Base
from storage.org import Org
from storage.org_llm_settings_store import OrgLLMSettingsStore
from storage.user import User
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session_maker(async_engine):
"""Create an async session maker for testing."""
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
@pytest.mark.asyncio
async def test_get_current_org_by_user_id_success(async_session_maker):
"""
GIVEN: A user exists with a current_org_id
WHEN: get_current_org_by_user_id is called
THEN: The user's current organization is returned
"""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org', default_llm_model='claude-3')
session.add(org)
await session.flush()
user = User(id=uuid.uuid4(), current_org_id=org.id)
session.add(user)
await session.commit()
user_id = str(user.id)
# Act
store = OrgLLMSettingsStore(db_session=session)
result = await store.get_current_org_by_user_id(user_id)
# Assert
assert result is not None
assert result.name == 'test-org'
assert result.default_llm_model == 'claude-3'
@pytest.mark.asyncio
async def test_get_current_org_by_user_id_user_not_found(async_session_maker):
"""
GIVEN: A user does not exist in the database
WHEN: get_current_org_by_user_id is called
THEN: None is returned
"""
# Arrange
non_existent_id = str(uuid.uuid4())
# Act
async with async_session_maker() as session:
store = OrgLLMSettingsStore(db_session=session)
result = await store.get_current_org_by_user_id(non_existent_id)
# Assert
assert result is None
@pytest.mark.asyncio
async def test_update_org_llm_settings_success(async_session_maker):
"""
GIVEN: An organization exists in the database
WHEN: update_org_llm_settings is called with new values
THEN: The organization's LLM settings are updated and returned
"""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org', default_llm_model='old-model')
session.add(org)
await session.commit()
org_id = org.id
update_data = OrgLLMSettingsUpdate(
default_llm_model='new-model',
agent='CodeActAgent',
confirmation_mode=True,
)
# Act
store = OrgLLMSettingsStore(db_session=session)
with patch(
'storage.org_llm_settings_store.OrgMemberStore.update_all_members_llm_settings_async',
AsyncMock(),
):
result = await store.update_org_llm_settings(org_id, update_data)
# Assert
assert result is not None
assert result.default_llm_model == 'new-model'
assert result.agent == 'CodeActAgent'
assert result.confirmation_mode is True
@pytest.mark.asyncio
async def test_update_org_llm_settings_org_not_found(async_session_maker):
"""
GIVEN: An organization does not exist in the database
WHEN: update_org_llm_settings is called
THEN: None is returned
"""
# Arrange
non_existent_org_id = uuid.uuid4()
update_data = OrgLLMSettingsUpdate(default_llm_model='new-model')
# Act
async with async_session_maker() as session:
store = OrgLLMSettingsStore(db_session=session)
result = await store.update_org_llm_settings(non_existent_org_id, update_data)
# Assert
assert result is None
@pytest.mark.asyncio
async def test_update_org_llm_settings_propagates_to_members(async_session_maker):
"""
GIVEN: An organization exists with update data containing member-relevant settings
WHEN: update_org_llm_settings is called
THEN: Member settings are propagated via OrgMemberStore
"""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org', default_llm_model='old-model')
session.add(org)
await session.commit()
org_id = org.id
update_data = OrgLLMSettingsUpdate(
default_llm_model='new-model',
llm_api_key='new-api-key',
)
# Act
store = OrgLLMSettingsStore(db_session=session)
with patch(
'storage.org_llm_settings_store.OrgMemberStore.update_all_members_llm_settings_async',
AsyncMock(),
) as mock_update_members:
await store.update_org_llm_settings(org_id, update_data)
# Assert
mock_update_members.assert_called_once()
call_args = mock_update_members.call_args
member_settings = call_args[0][2]
assert member_settings.llm_model == 'new-model'
assert member_settings.llm_api_key == 'new-api-key'

Some files were not shown because too many files have changed in this diff Show More