Compare commits

..

3 Commits

Author SHA1 Message Date
openhands
d4f7f07d5d test: add comprehensive tests for v1-git-service query parameter changes
- Add tests verifying query parameters are used instead of path segments
- Add tests for preserving slashes in paths (main fix purpose)
- Add tests for session API key headers
- Add tests for V1 to V0 status mapping
- Add tests for getGitChangeDiff endpoint

Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-02 02:59:27 +00:00
chuckbutkus
a34dc949ce Merge branch 'main' into fix/git-api-use-query-params 2026-03-01 21:39:52 -05:00
openhands
80e4fe1226 fix: use query parameters for V1 git API endpoints to preserve path slashes
Update V1GitService to pass path as a query parameter instead of embedding
it in the URL path segment. This fixes URL path normalization issues with
Traefik/Gateway API where encoded slashes (%2F) in path segments would be
decoded and normalized, causing leading slashes to be lost.

For example, /workspace/project was arriving as workspace/project.

Using query parameters (e.g., ?path=/workspace/project) avoids this issue
as they are passed through without path normalization.

Requires corresponding backend change in software-agent-sdk.

Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-01 05:09:30 +00:00
125 changed files with 3333 additions and 4808 deletions

View File

@@ -193,20 +193,14 @@ class GithubManager(Manager):
github_view.installation_id
)
# Store the installation token
await self.token_manager.store_org_token(
self.token_manager.store_org_token(
github_view.installation_id, installation_token
)
# Add eyes reaction to acknowledge we've read the request
self._add_reaction(github_view, 'eyes', installation_token)
await self.start_job(github_view)
async def send_message(self, message: str, github_view: ResolverViewInterface):
"""Send a message to GitHub.
Args:
message: The message content to send (plain text string)
github_view: The GitHub view object containing issue/PR/comment info
"""
async def send_message(self, message: Message, github_view: ResolverViewInterface):
installation_token = self.token_manager.load_org_token(
github_view.installation_id
)
@@ -214,12 +208,14 @@ class GithubManager(Manager):
logger.warning('Missing installation token')
return
outgoing_message = message.message
if isinstance(github_view, GithubInlinePRComment):
with Github(auth=Auth.Token(installation_token)) as github_client:
repo = github_client.get_repo(github_view.full_repo_name)
pr = repo.get_pull(github_view.issue_number)
pr.create_review_comment_reply(
comment_id=github_view.comment_id, body=message
comment_id=github_view.comment_id, body=outgoing_message
)
elif (
@@ -230,7 +226,7 @@ class GithubManager(Manager):
with Github(auth=Auth.Token(installation_token)) as github_client:
repo = github_client.get_repo(github_view.full_repo_name)
issue = repo.get_issue(number=github_view.issue_number)
issue.create_comment(message)
issue.create_comment(outgoing_message)
else:
logger.warning('Unsupported location')
@@ -249,7 +245,7 @@ class GithubManager(Manager):
)
try:
msg_info: str = ''
msg_info = None
try:
user_info = github_view.user_info
@@ -365,13 +361,15 @@ class GithubManager(Manager):
msg_info = get_session_expired_message(user_info.username)
await self.send_message(msg_info, github_view)
msg = self.create_outgoing_message(msg_info)
await self.send_message(msg, github_view)
except Exception:
logger.exception('[Github]: Error starting job')
await self.send_message(
'Uh oh! There was an unexpected error starting the job :(', github_view
msg = self.create_outgoing_message(
msg='Uh oh! There was an unexpected error starting the job :('
)
await self.send_message(msg, github_view)
try:
await self.data_collector.save_data(github_view)

View File

@@ -14,6 +14,7 @@ from integrations.solvability.models.summary import SolvabilitySummary
from integrations.utils import ENABLE_SOLVABILITY_ANALYSIS
from pydantic import ValidationError
from server.config import get_config
from storage.database import session_maker
from storage.saas_settings_store import SaasSettingsStore
from openhands.core.config import LLMConfig
@@ -89,6 +90,7 @@ async def summarize_issue_solvability(
# Grab the user's information so we can load their LLM configuration
store = SaasSettingsStore(
user_id=github_view.user_info.keycloak_user_id,
session_maker=session_maker,
config=get_config(),
)

View File

@@ -24,6 +24,7 @@ from jinja2 import Environment
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
from server.auth.token_manager import TokenManager
from server.config import get_config
from storage.database import session_maker
from storage.org_store import OrgStore
from storage.proactive_conversation_store import ProactiveConversationStore
from storage.saas_secrets_store import SaasSecretsStore
@@ -152,7 +153,9 @@ class GithubIssue(ResolverViewInterface):
return user_instructions, conversation_instructions
async def _get_user_secrets(self):
secrets_store = SaasSecretsStore(self.user_info.keycloak_user_id, get_config())
secrets_store = SaasSecretsStore(
self.user_info.keycloak_user_id, session_maker, get_config()
)
user_secrets = await secrets_store.load()
return user_secrets.custom_secrets if user_secrets else None

View File

@@ -121,11 +121,12 @@ class GitlabManager(Manager):
# Check if the user has write access to the repository
return has_write_access
async def send_message(self, message: str, gitlab_view: ResolverViewInterface):
"""Send a message to GitLab based on the view type.
async def send_message(self, message: Message, gitlab_view: ResolverViewInterface):
"""
Send a message to GitLab based on the view type.
Args:
message: The message content to send (plain text string)
message: The message to send
gitlab_view: The GitLab view object containing issue/PR/comment info
"""
keycloak_user_id = gitlab_view.user_info.keycloak_user_id
@@ -137,6 +138,8 @@ class GitlabManager(Manager):
external_auth_id=keycloak_user_id
)
outgoing_message = message.message
if isinstance(gitlab_view, GitlabInlineMRComment) or isinstance(
gitlab_view, GitlabMRComment
):
@@ -144,7 +147,7 @@ class GitlabManager(Manager):
gitlab_view.project_id,
gitlab_view.issue_number,
gitlab_view.discussion_id,
message,
message.message,
)
elif isinstance(gitlab_view, GitlabIssueComment):
@@ -152,14 +155,14 @@ class GitlabManager(Manager):
gitlab_view.project_id,
gitlab_view.issue_number,
gitlab_view.discussion_id,
message,
outgoing_message,
)
elif isinstance(gitlab_view, GitlabIssue):
await gitlab_service.reply_to_issue(
gitlab_view.project_id,
gitlab_view.issue_number,
None, # no discussion id, issue is tagged
message,
outgoing_message,
)
else:
logger.warning(
@@ -259,10 +262,12 @@ class GitlabManager(Manager):
msg_info = get_session_expired_message(user_info.username)
# Send the acknowledgment message
await self.send_message(msg_info, gitlab_view)
msg = self.create_outgoing_message(msg_info)
await self.send_message(msg, gitlab_view)
except Exception as e:
logger.exception(f'[GitLab] Error starting job: {str(e)}')
await self.send_message(
'Uh oh! There was an unexpected error starting the job :(', gitlab_view
msg = self.create_outgoing_message(
msg='Uh oh! There was an unexpected error starting the job :('
)
await self.send_message(msg, gitlab_view)

View File

@@ -6,6 +6,7 @@ from integrations.utils import HOST, get_oh_labels, has_exact_mention
from jinja2 import Environment
from server.auth.token_manager import TokenManager
from server.config import get_config
from storage.database import session_maker
from storage.saas_secrets_store import SaasSecretsStore
from openhands.core.logger import openhands_logger as logger
@@ -77,7 +78,9 @@ class GitlabIssue(ResolverViewInterface):
return user_instructions, conversation_instructions
async def _get_user_secrets(self):
secrets_store = SaasSecretsStore(self.user_info.keycloak_user_id, get_config())
secrets_store = SaasSecretsStore(
self.user_info.keycloak_user_id, session_maker, get_config()
)
user_secrets = await secrets_store.load()
return user_secrets.custom_secrets if user_secrets else None
@@ -446,5 +449,3 @@ class GitlabFactory:
previous_comments=[],
is_mr=True,
)
raise ValueError(f'Unhandled GitLab webhook event: {message}')

View File

@@ -341,25 +341,17 @@ class JiraManager(Manager):
async def send_message(
self,
message: str,
message: Message,
issue_key: str,
jira_cloud_id: str,
svc_acc_email: str,
svc_acc_api_key: str,
):
"""Send a comment to a Jira issue.
Args:
message: The message content to send (plain text string)
issue_key: The Jira issue key (e.g., 'PROJ-123')
jira_cloud_id: The Jira Cloud ID
svc_acc_email: Service account email for authentication
svc_acc_api_key: Service account API key for authentication
"""
"""Send a comment to a Jira issue."""
url = (
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
)
data = {'body': message}
data = {'body': message.message}
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
response = await client.post(
url, auth=(svc_acc_email, svc_acc_api_key), json=data
@@ -374,7 +366,7 @@ class JiraManager(Manager):
view.jira_workspace.svc_acc_api_key
)
await self.send_message(
msg,
self.create_outgoing_message(msg=msg),
issue_key=view.payload.issue_key,
jira_cloud_id=view.jira_workspace.jira_cloud_id,
svc_acc_email=view.jira_workspace.svc_acc_email,
@@ -396,7 +388,7 @@ class JiraManager(Manager):
try:
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
await self.send_message(
error_msg,
self.create_outgoing_message(msg=error_msg),
issue_key=payload.issue_key,
jira_cloud_id=workspace.jira_cloud_id,
svc_acc_email=workspace.svc_acc_email,

View File

@@ -212,6 +212,8 @@ class JiraPayloadParser:
missing.append('issue.id')
if not issue_key:
missing.append('issue.key')
if not user_email:
missing.append('user.emailAddress')
if not display_name:
missing.append('user.displayName')
if not account_id:

View File

@@ -418,7 +418,7 @@ class JiraDcManager(Manager):
jira_dc_view.jira_dc_workspace.svc_acc_api_key
)
await self.send_message(
msg_info,
self.create_outgoing_message(msg=msg_info),
issue_key=jira_dc_view.job_context.issue_key,
base_api_url=jira_dc_view.job_context.base_api_url,
svc_acc_api_key=api_key,
@@ -456,19 +456,12 @@ class JiraDcManager(Manager):
return title, description
async def send_message(
self, message: str, issue_key: str, base_api_url: str, svc_acc_api_key: str
self, message: Message, issue_key: str, base_api_url: str, svc_acc_api_key: str
):
"""Send message/comment to Jira DC issue.
Args:
message: The message content to send (plain text string)
issue_key: The Jira issue key (e.g., 'PROJ-123')
base_api_url: The base API URL for the Jira DC instance
svc_acc_api_key: Service account API key for authentication
"""
"""Send message/comment to Jira DC issue."""
url = f'{base_api_url}/rest/api/2/issue/{issue_key}/comment'
headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
data = {'body': message}
data = {'body': message.message}
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
response = await client.post(url, headers=headers, json=data)
response.raise_for_status()
@@ -488,7 +481,7 @@ class JiraDcManager(Manager):
try:
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
await self.send_message(
error_msg,
self.create_outgoing_message(msg=error_msg),
issue_key=job_context.issue_key,
base_api_url=job_context.base_api_url,
svc_acc_api_key=api_key,
@@ -509,7 +502,7 @@ class JiraDcManager(Manager):
)
await self.send_message(
comment_msg,
self.create_outgoing_message(msg=comment_msg),
issue_key=jira_dc_view.job_context.issue_key,
base_api_url=jira_dc_view.job_context.base_api_url,
svc_acc_api_key=api_key,

View File

@@ -19,7 +19,7 @@ class JiraDcViewInterface(ABC):
conversation_id: str
@abstractmethod
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Get initial instructions for the conversation."""
pass

View File

@@ -36,7 +36,7 @@ class JiraDcNewConversationView(JiraDcViewInterface):
selected_repo: str | None
conversation_id: str
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
instructions_template = jinja_env.get_template('jira_dc_instructions.j2')
@@ -61,7 +61,7 @@ class JiraDcNewConversationView(JiraDcViewInterface):
provider_tokens = await self.saas_user_auth.get_provider_tokens()
user_secrets = await self.saas_user_auth.get_secrets()
instructions, user_msg = await self._get_instructions(jinja_env)
instructions, user_msg = self._get_instructions(jinja_env)
try:
agent_loop_info = await create_new_conversation(
@@ -113,7 +113,7 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
selected_repo: str | None
conversation_id: str
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
user_msg_template = jinja_env.get_template('jira_dc_existing_conversation.j2')
@@ -167,7 +167,7 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
if not agent_state or agent_state == AgentState.LOADING:
raise StartingConvoException('Conversation is still starting')
_, user_msg = await self._get_instructions(jinja_env)
_, user_msg = self._get_instructions(jinja_env)
user_message_event = MessageAction(content=user_msg)
await conversation_manager.send_event_to_conversation(
self.conversation_id, event_to_dict(user_message_event)

View File

@@ -408,7 +408,7 @@ class LinearManager(Manager):
linear_view.linear_workspace.svc_acc_api_key
)
await self.send_message(
msg_info,
self.create_outgoing_message(msg=msg_info),
linear_view.job_context.issue_id,
api_key,
)
@@ -473,14 +473,8 @@ class LinearManager(Manager):
return title, description
async def send_message(self, message: str, issue_id: str, api_key: str):
"""Send message/comment to Linear issue.
Args:
message: The message content to send (plain text string)
issue_id: The Linear issue ID to comment on
api_key: The Linear API key for authentication
"""
async def send_message(self, message: Message, issue_id: str, api_key: str):
"""Send message/comment to Linear issue."""
query = """
mutation CommentCreate($input: CommentCreateInput!) {
commentCreate(input: $input) {
@@ -491,7 +485,7 @@ class LinearManager(Manager):
}
}
"""
variables = {'input': {'issueId': issue_id, 'body': message}}
variables = {'input': {'issueId': issue_id, 'body': message.message}}
return await self._query_api(query, variables, api_key)
async def _send_error_comment(
@@ -504,7 +498,9 @@ class LinearManager(Manager):
try:
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
await self.send_message(error_msg, issue_id, api_key)
await self.send_message(
self.create_outgoing_message(msg=error_msg), issue_id, api_key
)
except Exception as e:
logger.error(f'[Linear] Failed to send error comment: {str(e)}')
@@ -521,7 +517,7 @@ class LinearManager(Manager):
)
await self.send_message(
comment_msg,
self.create_outgoing_message(msg=comment_msg),
linear_view.job_context.issue_id,
api_key,
)

View File

@@ -19,7 +19,7 @@ class LinearViewInterface(ABC):
conversation_id: str
@abstractmethod
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Get initial instructions for the conversation."""
pass

View File

@@ -33,7 +33,7 @@ class LinearNewConversationView(LinearViewInterface):
selected_repo: str | None
conversation_id: str
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
instructions_template = jinja_env.get_template('linear_instructions.j2')
@@ -58,7 +58,7 @@ class LinearNewConversationView(LinearViewInterface):
provider_tokens = await self.saas_user_auth.get_provider_tokens()
user_secrets = await self.saas_user_auth.get_secrets()
instructions, user_msg = await self._get_instructions(jinja_env)
instructions, user_msg = self._get_instructions(jinja_env)
try:
agent_loop_info = await create_new_conversation(
@@ -110,7 +110,7 @@ class LinearExistingConversationView(LinearViewInterface):
selected_repo: str | None
conversation_id: str
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
user_msg_template = jinja_env.get_template('linear_existing_conversation.j2')
@@ -164,7 +164,7 @@ class LinearExistingConversationView(LinearViewInterface):
if not agent_state or agent_state == AgentState.LOADING:
raise StartingConvoException('Conversation is still starting')
_, user_msg = await self._get_instructions(jinja_env)
_, user_msg = self._get_instructions(jinja_env)
user_message_event = MessageAction(content=user_msg)
await conversation_manager.send_event_to_conversation(
self.conversation_id, event_to_dict(user_message_event)

View File

@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from typing import Any
from integrations.models import Message, SourceType
@@ -13,15 +12,14 @@ class Manager(ABC):
raise NotImplementedError
@abstractmethod
def send_message(self, message: str, *args: Any, **kwargs: Any):
"""Send message to integration from OpenHands server.
Args:
message: The message content to send (plain text string).
"""
def send_message(self, message: Message):
"Send message to integration from Openhands server"
raise NotImplementedError
@abstractmethod
def start_job(self):
"Kick off a job with openhands agent"
raise NotImplementedError
def create_outgoing_message(self, msg: str | dict, ephemeral: bool = False):
return Message(source=SourceType.OPENHANDS, message=msg, ephemeral=ephemeral)

View File

@@ -1,5 +1,4 @@
from enum import Enum
from typing import Any
from pydantic import BaseModel
@@ -17,16 +16,8 @@ class SourceType(str, Enum):
class Message(BaseModel):
"""Message model for incoming webhook payloads from integrations.
Note: This model is intended for INCOMING messages only.
For outgoing messages (e.g., sending comments to GitHub/GitLab),
pass strings directly to the send_message methods instead of
wrapping them in a Message object.
"""
source: SourceType
message: dict[str, Any]
message: str | dict
ephemeral: bool = False

View File

@@ -1,5 +1,4 @@
import re
from typing import Any
import jwt
from integrations.manager import Manager
@@ -23,8 +22,7 @@ from server.constants import SLACK_CLIENT_ID
from server.utils.conversation_callback_utils import register_callback_processor
from slack_sdk.oauth import AuthorizeUrlGenerator
from slack_sdk.web.async_client import AsyncWebClient
from sqlalchemy import select
from storage.database import a_session_maker
from storage.database import session_maker
from storage.slack_user import SlackUser
from openhands.core.logger import openhands_logger as logger
@@ -65,11 +63,12 @@ class SlackManager(Manager):
) -> tuple[SlackUser | None, UserAuth | None]:
# We get the user and correlate them back to a user in OpenHands - if we can
slack_user = None
async with a_session_maker() as session:
result = await session.execute(
select(SlackUser).where(SlackUser.slack_user_id == slack_user_id)
with session_maker() as session:
slack_user = (
session.query(SlackUser)
.filter(SlackUser.slack_user_id == slack_user_id)
.first()
)
slack_user = result.scalar_one_or_none()
# slack_view.slack_to_openhands_user = slack_user # attach user auth info to view
@@ -203,7 +202,9 @@ class SlackManager(Manager):
msg = self.login_link.format(link)
logger.info('slack_not_yet_authenticated')
await self.send_message(msg, slack_view, ephemeral=True)
await self.send_message(
self.create_outgoing_message(msg, ephemeral=True), slack_view
)
return
if not await self.is_job_requested(message, slack_view):
@@ -211,40 +212,27 @@ class SlackManager(Manager):
await self.start_job(slack_view)
async def send_message(
self,
message: str | dict[str, Any],
slack_view: SlackViewInterface,
ephemeral: bool = False,
):
"""Send a message to Slack.
Args:
message: The message content. Can be a string (for simple text) or
a dict with 'text' and 'blocks' keys (for structured messages).
slack_view: The Slack view object containing channel/thread info.
ephemeral: If True, send as an ephemeral message visible only to the user.
"""
async def send_message(self, message: Message, slack_view: SlackViewInterface):
client = AsyncWebClient(token=slack_view.bot_access_token)
if ephemeral and isinstance(message, str):
if message.ephemeral and isinstance(message.message, str):
await client.chat_postEphemeral(
channel=slack_view.channel_id,
markdown_text=message,
markdown_text=message.message,
user=slack_view.slack_user_id,
thread_ts=slack_view.thread_ts,
)
elif ephemeral and isinstance(message, dict):
elif message.ephemeral and isinstance(message.message, dict):
await client.chat_postEphemeral(
channel=slack_view.channel_id,
user=slack_view.slack_user_id,
thread_ts=slack_view.thread_ts,
text=message['text'],
blocks=message['blocks'],
text=message.message['text'],
blocks=message.message['blocks'],
)
else:
await client.chat_postMessage(
channel=slack_view.channel_id,
markdown_text=message,
markdown_text=message.message,
thread_ts=slack_view.message_ts,
)
@@ -291,7 +279,10 @@ class SlackManager(Manager):
repos, slack_view.message_ts, slack_view.thread_ts
),
}
await self.send_message(repo_selection_msg, slack_view, ephemeral=True)
await self.send_message(
self.create_outgoing_message(repo_selection_msg, ephemeral=True),
slack_view,
)
return False
@@ -377,10 +368,9 @@ class SlackManager(Manager):
except StartingConvoException as e:
msg_info = str(e)
await self.send_message(msg_info, slack_view)
await self.send_message(self.create_outgoing_message(msg_info), slack_view)
except Exception:
logger.exception('[Slack]: Error starting job')
await self.send_message(
'Uh oh! There was an unexpected error starting the job :(', slack_view
)
msg = 'Uh oh! There was an unexpected error starting the job :('
await self.send_message(self.create_outgoing_message(msg), slack_view)

View File

@@ -24,7 +24,7 @@ class SlackViewInterface(SummaryExtractionTracker, ABC):
v1_enabled: bool
@abstractmethod
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
pass

View File

@@ -75,7 +75,7 @@ class SlackUnkownUserView(SlackViewInterface):
team_id: str
v1_enabled: bool
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
raise NotImplementedError
async def create_or_update_conversation(self, jinja_env: Environment):
@@ -118,7 +118,7 @@ class SlackNewConversationView(SlackViewInterface):
return block['user_id']
return ''
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
user_info: SlackUser = self.slack_to_openhands_user
@@ -242,9 +242,7 @@ class SlackNewConversationView(SlackViewInterface):
self, jinja: Environment, provider_tokens, user_secrets
) -> None:
"""Create conversation using the legacy V0 system."""
user_instructions, conversation_instructions = await self._get_instructions(
jinja
)
user_instructions, conversation_instructions = self._get_instructions(jinja)
# Determine git provider from repository
git_provider = None
@@ -275,9 +273,7 @@ class SlackNewConversationView(SlackViewInterface):
async def _create_v1_conversation(self, jinja: Environment) -> None:
"""Create conversation using the new V1 app conversation system."""
user_instructions, conversation_instructions = await self._get_instructions(
jinja
)
user_instructions, conversation_instructions = self._get_instructions(jinja)
# Create the initial message request
initial_message = SendMessageRequest(
@@ -350,7 +346,7 @@ class SlackNewConversationFromRepoFormView(SlackNewConversationView):
class SlackUpdateExistingConversationView(SlackNewConversationView):
slack_conversation: SlackConversation
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
client = WebClient(token=self.bot_access_token)
result = client.conversations_replies(
channel=self.channel_id,
@@ -405,7 +401,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
if not agent_state or agent_state == AgentState.LOADING:
raise StartingConvoException('Conversation is still starting')
instructions, _ = await self._get_instructions(jinja)
instructions, _ = self._get_instructions(jinja)
user_msg = MessageAction(content=instructions)
await conversation_manager.send_event_to_conversation(
self.conversation_id, event_to_dict(user_msg)
@@ -473,7 +469,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
agent_server_url = get_agent_server_url_from_sandbox(running_sandbox)
# 4. Prepare the message content
user_msg, _ = await self._get_instructions(jinja)
user_msg, _ = self._get_instructions(jinja)
# 5. Create the message request
send_message_request = SendMessageRequest(

View File

@@ -42,11 +42,11 @@ async def store_repositories_in_db(repos: list[Repository], user_id: str) -> Non
try:
# Store repositories in the repos table
repo_store = RepositoryStore.get_instance(config)
await repo_store.store_projects(stored_repos)
repo_store.store_projects(stored_repos)
# Store user-repository mappings in the user-repos table
user_repo_store = UserRepositoryMapStore.get_instance(config)
await user_repo_store.store_user_repo_mappings(user_repos)
user_repo_store.store_user_repo_mappings(user_repos)
logger.info(f'Saved repos for user {user_id}')
except Exception:

View File

@@ -3,8 +3,8 @@ from uuid import UUID
import stripe
from server.constants import STRIPE_API_KEY
from server.logger import logger
from sqlalchemy import select
from storage.database import a_session_maker
from sqlalchemy.orm import Session
from storage.database import session_maker
from storage.org import Org
from storage.org_store import OrgStore
from storage.stripe_customer import StripeCustomer
@@ -15,10 +15,12 @@ stripe.api_key = STRIPE_API_KEY
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
async with a_session_maker() as session:
stmt = select(StripeCustomer).where(StripeCustomer.org_id == org_id)
result = await session.execute(stmt)
stripe_customer = result.scalar_one_or_none()
with session_maker() as session:
stripe_customer = (
session.query(StripeCustomer)
.filter(StripeCustomer.org_id == org_id)
.first()
)
if stripe_customer:
return stripe_customer.stripe_customer_id
@@ -72,7 +74,7 @@ async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
)
# Save the stripe customer in the local db
async with a_session_maker() as session:
with session_maker() as session:
session.add(
StripeCustomer(
keycloak_user_id=user_id,
@@ -80,7 +82,7 @@ async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
stripe_customer_id=customer.id,
)
)
await session.commit()
session.commit()
logger.info(
'created_customer',
@@ -106,27 +108,26 @@ async def has_payment_method_by_user_id(user_id: str) -> bool:
return bool(payment_methods.data)
async def migrate_customer(user_id: str, org: Org):
async with a_session_maker() as session:
result = await session.execute(
select(StripeCustomer).where(StripeCustomer.keycloak_user_id == user_id)
)
stripe_customer = result.scalar_one_or_none()
if stripe_customer is None:
return
stripe_customer.org_id = org.id
customer = await stripe.Customer.modify_async(
id=stripe_customer.stripe_customer_id,
email=org.contact_email,
metadata={'user_id': '', 'org_id': str(org.id)},
)
async def migrate_customer(session: Session, user_id: str, org: Org):
stripe_customer = (
session.query(StripeCustomer)
.filter(StripeCustomer.keycloak_user_id == user_id)
.first()
)
if stripe_customer is None:
return
stripe_customer.org_id = org.id
customer = await stripe.Customer.modify_async(
id=stripe_customer.stripe_customer_id,
email=org.contact_email,
metadata={'user_id': '', 'org_id': str(org.id)},
)
logger.info(
'migrated_customer',
extra={
'user_id': user_id,
'org_id': str(org.id),
'stripe_customer_id': customer.id,
},
)
await session.commit()
logger.info(
'migrated_customer',
extra={
'user_id': user_id,
'org_id': str(org.id),
'stripe_customer_id': customer.id,
},
)

View File

@@ -38,7 +38,7 @@ class ResolverViewInterface(SummaryExtractionTracker):
is_public_repo: bool
raw_payload: dict
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"Instructions passed when conversation is first initialized"
raise NotImplementedError()

17
enterprise/poetry.lock generated
View File

@@ -1591,9 +1591,6 @@ files = [
{file = "cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c"},
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a"},
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356"},
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da"},
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257"},
{file = "cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7"},
{file = "cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d"},
]
@@ -6171,23 +6168,23 @@ opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
pathspec = ">=0.12.1"
pexpect = "*"
pg8000 = ">=1.31.5"
pillow = ">=12.1.1"
pillow = ">=11.3"
playwright = ">=1.55"
poetry = ">=2.1.2"
prompt-toolkit = ">=3.0.50"
protobuf = ">=5.29.6,<6"
protobuf = ">=5,<6"
psutil = "*"
pybase62 = ">=1"
pygithub = ">=2.5"
pyjwt = ">=2.9"
pylatexenc = "*"
pypdf = ">=6.7.2"
pypdf = ">=6"
python-docx = "*"
python-dotenv = "*"
python-frontmatter = ">=1.1"
python-jose = {version = ">=3.3", extras = ["cryptography"]}
python-json-logger = ">=3.2.1"
python-multipart = ">=0.0.22"
python-multipart = "*"
python-pptx = "*"
python-socketio = "5.14"
pythonnet = "*"
@@ -6200,7 +6197,7 @@ setuptools = ">=78.1.1"
shellingham = ">=1.5.4"
sqlalchemy = {version = ">=2.0.40", extras = ["asyncio"]}
sse-starlette = ">=3.0.2"
starlette = ">=0.49.1"
starlette = ">=0.48"
tenacity = ">=8.5,<10"
termcolor = "*"
toml = "*"
@@ -11964,7 +11961,7 @@ description = "Python for Window Extensions"
optional = false
python-versions = "*"
groups = ["main"]
markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
markers = "platform_system == \"Windows\" or sys_platform == \"win32\""
files = [
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
@@ -14912,4 +14909,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = "^3.12,<3.14"
content-hash = "ef037f6d6085d26166d35c56ce266439f8f1a4fea90bc43ccf15cfeaf116cae5"
content-hash = "1cad6029269393af67155e930c72eae2c03da02e4b3a3699823f6168c14a4218"

View File

@@ -49,7 +49,7 @@ prometheus-client = "^0.24.0"
pandas = "^2.2.0"
numpy = "^2.2.0"
mcp = "^1.10.0"
pillow = "^12.1.1"
pillow = "^12.1.0"
[tool.poetry.group.dev.dependencies]
ruff = "0.8.3"

View File

@@ -1,4 +1,5 @@
from storage.blocked_email_domain_store import BlockedEmailDomainStore
from storage.database import session_maker
from openhands.core.logger import openhands_logger as logger
@@ -22,7 +23,7 @@ class DomainBlocker:
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
return None
async def is_domain_blocked(self, email: str) -> bool:
def is_domain_blocked(self, email: str) -> bool:
"""Check if email domain is blocked by querying the database directly via SQL.
Supports blocking:
@@ -44,7 +45,7 @@ class DomainBlocker:
try:
# Query database directly via SQL to check if domain is blocked
is_blocked = await self.store.is_domain_blocked(domain)
is_blocked = self.store.is_domain_blocked(domain)
if is_blocked:
logger.warning(f'Email domain {domain} is blocked for email: {email}')
@@ -62,5 +63,5 @@ class DomainBlocker:
# Initialize store and domain blocker
_store = BlockedEmailDomainStore()
_store = BlockedEmailDomainStore(session_maker=session_maker)
domain_blocker = DomainBlocker(store=_store)

View File

@@ -1,7 +1,7 @@
from integrations.github.github_service import SaaSGitHubService
from pydantic import SecretStr
from server.auth.auth_utils import user_verifier
from enterprise.server.auth.auth_utils import user_verifier
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_types import GitHubUser

View File

@@ -18,10 +18,9 @@ from server.auth.token_manager import TokenManager
from server.config import get_config
from server.logger import logger
from server.rate_limit import RateLimiter, create_redis_rate_limiter
from sqlalchemy import delete, select
from storage.api_key_store import ApiKeyStore
from storage.auth_tokens import AuthTokens
from storage.database import a_session_maker
from storage.database import session_maker
from storage.saas_secrets_store import SaasSecretsStore
from storage.saas_settings_store import SaasSettingsStore
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
@@ -125,7 +124,7 @@ class SaasUserAuth(UserAuth):
if secrets_store:
return secrets_store
user_id = await self.get_user_id()
secrets_store = SaasSecretsStore(user_id, get_config())
secrets_store = SaasSecretsStore(user_id, session_maker, get_config())
self.secrets_store = secrets_store
return secrets_store
@@ -162,13 +161,12 @@ class SaasUserAuth(UserAuth):
try:
# TODO: I think we can do this in a single request if we refactor
async with a_session_maker() as session:
result = await session.execute(
select(AuthTokens).where(
AuthTokens.keycloak_user_id == self.user_id
)
with session_maker() as session:
tokens = (
session.query(AuthTokens)
.where(AuthTokens.keycloak_user_id == self.user_id)
.all()
)
tokens = result.scalars().all()
for token in tokens:
idp_type = ProviderType(token.identity_provider)
@@ -194,11 +192,11 @@ class SaasUserAuth(UserAuth):
'idp_type': token.identity_provider,
},
)
async with a_session_maker() as session:
await session.execute(
delete(AuthTokens).where(AuthTokens.id == token.id)
)
await session.commit()
with session_maker() as session:
session.query(AuthTokens).filter(
AuthTokens.id == token.id
).delete()
session.commit()
raise
self.provider_tokens = MappingProxyType(provider_tokens)
@@ -212,7 +210,7 @@ class SaasUserAuth(UserAuth):
if settings_store:
return settings_store
user_id = await self.get_user_id()
settings_store = SaasSettingsStore(user_id, get_config())
settings_store = SaasSettingsStore(user_id, session_maker, get_config())
self.settings_store = settings_store
return settings_store
@@ -280,7 +278,7 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
return None
api_key_store = ApiKeyStore.get_instance()
user_id = await api_key_store.validate_api_key(api_key)
user_id = api_key_store.validate_api_key(api_key)
if not user_id:
return None
offline_token = await token_manager.load_offline_token(user_id)
@@ -329,7 +327,7 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
email_verified = access_token_payload['email_verified']
# Check if email domain is blocked
if email and await domain_blocker.is_domain_blocked(email):
if email and domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for existing user with email: {email}'
)

View File

@@ -38,9 +38,9 @@ from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid
from server.config import get_config
from server.logger import logger
from sqlalchemy import String as SQLString
from sqlalchemy import select, type_coerce
from sqlalchemy import type_coerce
from storage.auth_token_store import AuthTokenStore
from storage.database import a_session_maker
from storage.database import session_maker
from storage.github_app_installation import GithubAppInstallation
from storage.offline_token_store import OfflineTokenStore
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt
@@ -783,24 +783,25 @@ class TokenManager:
exc_info=True,
)
async def store_org_token(self, installation_id: int, installation_token: str):
def store_org_token(self, installation_id: int, installation_token: str):
"""Store a GitHub App installation token.
Args:
installation_id: GitHub installation ID (integer or string)
installation_token: The token to store
"""
async with a_session_maker() as session:
with session_maker() as session:
# Ensure installation_id is a string
str_installation_id = str(installation_id)
# Use type_coerce to ensure SQLAlchemy treats the parameter as a string
result = await session.execute(
select(GithubAppInstallation).filter(
installation = (
session.query(GithubAppInstallation)
.filter(
GithubAppInstallation.installation_id
== type_coerce(str_installation_id, SQLString)
)
.first()
)
installation = result.scalars().first()
if installation:
installation.encrypted_token = self.encrypt_text(installation_token)
else:
@@ -810,9 +811,9 @@ class TokenManager:
encrypted_token=self.encrypt_text(installation_token),
)
)
await session.commit()
session.commit()
async def load_org_token(self, installation_id: int) -> str | None:
def load_org_token(self, installation_id: int) -> str | None:
"""Load a GitHub App installation token.
Args:
@@ -821,16 +822,17 @@ class TokenManager:
Returns:
The decrypted token if found, None otherwise
"""
async with a_session_maker() as session:
with session_maker() as session:
# Ensure installation_id is a string and use type_coerce
str_installation_id = str(installation_id)
result = await session.execute(
select(GithubAppInstallation).filter(
installation = (
session.query(GithubAppInstallation)
.filter(
GithubAppInstallation.installation_id
== type_coerce(str_installation_id, SQLString)
)
.first()
)
installation = result.scalars().first()
if not installation:
return None
token = self.decrypt_text(installation.encrypted_token)

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from integrations.github.github_manager import GithubManager
from integrations.github.github_view import GithubViewType
from integrations.models import Message, SourceType
from integrations.utils import (
extract_summary_from_conversation_manager,
get_summary_instruction,
@@ -34,12 +35,16 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
send_summary_instruction: bool = True
async def _send_message_to_github(self, message: str) -> None:
"""Send a message to GitHub.
"""
Send a message to GitHub.
Args:
message: The message content to send to GitHub
"""
try:
# Create a message object for GitHub
message_obj = Message(source=SourceType.OPENHANDS, message=message)
# Get the token manager
token_manager = TokenManager()
@@ -48,8 +53,8 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
github_manager = GithubManager(token_manager, GitHubDataCollector())
# Send the message directly as a string
await github_manager.send_message(message, self.github_view)
# Send the message
await github_manager.send_message(message_obj, self.github_view)
logger.info(
f'[GitHub] Sent summary message to {self.github_view.full_repo_name}#{self.github_view.issue_number}'

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from integrations.gitlab.gitlab_manager import GitlabManager
from integrations.gitlab.gitlab_view import GitlabViewType
from integrations.models import Message, SourceType
from integrations.utils import (
extract_summary_from_conversation_manager,
get_summary_instruction,
@@ -13,7 +14,7 @@ from storage.conversation_callback import (
ConversationCallback,
ConversationCallbackProcessor,
)
from storage.database import a_session_maker
from storage.database import session_maker
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
@@ -27,7 +28,8 @@ gitlab_manager = GitlabManager(token_manager)
class GitlabCallbackProcessor(ConversationCallbackProcessor):
"""Processor for sending conversation summaries to GitLab.
"""
Processor for sending conversation summaries to GitLab.
This processor is used to send summaries of conversations to GitLab
when agent state changes occur.
@@ -37,18 +39,22 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
send_summary_instruction: bool = True
async def _send_message_to_gitlab(self, message: str) -> None:
"""Send a message to GitLab.
"""
Send a message to GitLab.
Args:
message: The message content to send to GitLab
"""
try:
# Create a message object for GitHub
message_obj = Message(source=SourceType.OPENHANDS, message=message)
# Get the token manager
token_manager = TokenManager()
gitlab_manager = GitlabManager(token_manager)
# Send the message directly as a string
await gitlab_manager.send_message(message, self.gitlab_view)
# Send the message
await gitlab_manager.send_message(message_obj, self.gitlab_view)
logger.info(
f'[GitLab] Sent summary message to {self.gitlab_view.full_repo_name}#{self.gitlab_view.issue_number}'
@@ -105,9 +111,9 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
self.send_summary_instruction = False
callback.set_processor(self)
callback.updated_at = datetime.now()
async with a_session_maker() as session:
with session_maker() as session:
session.merge(callback)
await session.commit()
session.commit()
return
# Extract the summary from the event store
@@ -126,9 +132,9 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
# Mark callback as completed status
callback.status = CallbackStatus.COMPLETED
callback.updated_at = datetime.now()
async with a_session_maker() as session:
with session_maker() as session:
session.merge(callback)
await session.commit()
session.commit()
except Exception as e:
logger.exception(

View File

@@ -37,7 +37,8 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
workspace_name: str
async def _send_comment_to_jira(self, message: str) -> None:
"""Send a comment to Jira issue.
"""
Send a comment to Jira issue.
Args:
message: The message content to send to Jira
@@ -58,9 +59,8 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
# Decrypt API key
api_key = jira_manager.token_manager.decrypt_text(workspace.svc_acc_api_key)
# Send comment directly as a string
await jira_manager.send_message(
message,
jira_manager.create_outgoing_message(msg=message),
issue_key=self.issue_key,
jira_cloud_id=workspace.jira_cloud_id,
svc_acc_email=workspace.svc_acc_email,

View File

@@ -37,7 +37,8 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
base_api_url: str
async def _send_comment_to_jira_dc(self, message: str) -> None:
"""Send a comment to Jira DC issue.
"""
Send a comment to Jira DC issue.
Args:
message: The message content to send to Jira DC
@@ -60,9 +61,8 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
workspace.svc_acc_api_key
)
# Send comment directly as a string
await jira_dc_manager.send_message(
message,
jira_dc_manager.create_outgoing_message(msg=message),
issue_key=self.issue_key,
base_api_url=self.base_api_url,
svc_acc_api_key=api_key,

View File

@@ -36,7 +36,8 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
workspace_name: str
async def _send_comment_to_linear(self, message: str) -> None:
"""Send a comment to Linear issue.
"""
Send a comment to Linear issue.
Args:
message: The message content to send to Linear
@@ -59,9 +60,9 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
workspace.svc_acc_api_key
)
# Send comment directly as a string
# Send comment
await linear_manager.send_message(
message,
linear_manager.create_outgoing_message(msg=message),
self.issue_id,
api_key,
)

View File

@@ -26,7 +26,8 @@ slack_manager = SlackManager(token_manager)
class SlackCallbackProcessor(ConversationCallbackProcessor):
"""Processor for sending conversation summaries to Slack.
"""
Processor for sending conversation summaries to Slack.
This processor is used to send summaries of conversations to Slack channels
when agent state changes occur.
@@ -40,13 +41,14 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
last_user_msg_id: int | None = None
async def _send_message_to_slack(self, message: str) -> None:
"""Send a message to Slack.
"""
Send a message to Slack using the conversation_manager's send_to_event_stream method.
Args:
message: The message content to send to Slack
"""
try:
# Create a message object for Slack view creation (incoming message format)
# Create a message object for Slack
message_obj = Message(
source=SourceType.SLACK,
message={
@@ -65,8 +67,9 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
slack_view = SlackFactory.create_slack_view_from_payload(
message_obj, slack_user, saas_user_auth
)
# Send the message directly as a string
await slack_manager.send_message(message, slack_view)
await slack_manager.send_message(
slack_manager.create_outgoing_message(message), slack_view
)
logger.info(
f'[Slack] Sent summary message to channel {self.channel_id} '

View File

@@ -251,7 +251,7 @@ async def delete_api_key(
)
# Delete the key
success = await api_key_store.delete_api_key_by_id(key_id)
success = api_key_store.delete_api_key_by_id(key_id)
if not success:
raise HTTPException(

View File

@@ -34,8 +34,7 @@ from server.services.org_invitation_service import (
OrgInvitationService,
UserAlreadyMemberError,
)
from sqlalchemy import select
from storage.database import a_session_maker
from storage.database import session_maker
from storage.user import User
from storage.user_store import UserStore
@@ -271,7 +270,7 @@ async def keycloak_callback(
# Fail open - continue with login if reCAPTCHA service unavailable
# Check if email domain is blocked
if email and await domain_blocker.is_domain_blocked(email):
if email and domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
)
@@ -611,20 +610,17 @@ async def accept_tos(request: Request):
# Update user settings with TOS acceptance
accepted_tos: datetime = datetime.now(timezone.utc)
async with a_session_maker() as session:
result = await session.execute(
select(User).where(User.id == uuid.UUID(user_id))
)
user = result.scalar_one_or_none()
with session_maker() as session:
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
if not user:
await session.rollback()
session.rollback()
logger.error('User for {user_id} not found.')
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'User does not exist'},
)
user.accepted_tos = accepted_tos
await session.commit()
session.commit()
logger.info(f'User {user_id} accepted TOS')

View File

@@ -11,10 +11,9 @@ from integrations import stripe_service
from pydantic import BaseModel
from server.constants import STRIPE_API_KEY
from server.logger import logger
from sqlalchemy import select
from starlette.datastructures import URL
from storage.billing_session import BillingSession
from storage.database import a_session_maker
from storage.database import session_maker
from storage.lite_llm_manager import LiteLlmManager
from storage.org import Org
from storage.subscription_access import SubscriptionAccess
@@ -107,17 +106,16 @@ async def get_subscription_access(
user_id: str = Depends(get_user_id),
) -> SubscriptionAccessResponse | None:
"""Get details of the currently valid subscription for the user."""
async with a_session_maker() as session:
with session_maker() as session:
now = datetime.now(UTC)
result = await session.execute(
select(SubscriptionAccess).where(
SubscriptionAccess.status == 'ACTIVE',
SubscriptionAccess.user_id == user_id,
SubscriptionAccess.start_at <= now,
SubscriptionAccess.end_at >= now,
)
subscription_access = (
session.query(SubscriptionAccess)
.filter(SubscriptionAccess.status == 'ACTIVE')
.filter(SubscriptionAccess.user_id == user_id)
.filter(SubscriptionAccess.start_at <= now)
.filter(SubscriptionAccess.end_at >= now)
.first()
)
subscription_access = result.scalar_one_or_none()
if not subscription_access:
return None
return SubscriptionAccessResponse(
@@ -199,7 +197,7 @@ async def create_checkout_session(
'checkout_session_id': checkout_session.id,
},
)
async with a_session_maker() as session:
with session_maker() as session:
billing_session = BillingSession(
id=checkout_session.id,
user_id=user_id,
@@ -208,7 +206,7 @@ async def create_checkout_session(
price_code='NA',
)
session.add(billing_session)
await session.commit()
session.commit()
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
@@ -217,14 +215,13 @@ async def create_checkout_session(
@billing_router.get('/success')
async def success_callback(session_id: str, request: Request):
# We can't use the auth cookie because of SameSite=strict
async with a_session_maker() as session:
result = await session.execute(
select(BillingSession).where(
BillingSession.id == session_id,
BillingSession.status == 'in_progress',
)
with session_maker() as session:
billing_session = (
session.query(BillingSession)
.filter(BillingSession.id == session_id)
.filter(BillingSession.status == 'in_progress')
.first()
)
billing_session = result.scalar_one_or_none()
if billing_session is None:
# Hopefully this never happens - we get a redirect from stripe where the session does not exist
@@ -256,8 +253,7 @@ async def success_callback(session_id: str, request: Request):
user_team_info, billing_session.user_id, str(user.current_org_id)
)
result = await session.execute(select(Org).where(Org.id == user.current_org_id))
org = result.scalar_one_or_none()
org = session.query(Org).filter(Org.id == user.current_org_id).first()
new_max_budget = max_budget + add_credits
await LiteLlmManager.update_team_and_users_budget(
@@ -283,7 +279,7 @@ async def success_callback(session_id: str, request: Request):
'stripe_customer_id': stripe_session.customer,
},
)
await session.commit()
session.commit()
return RedirectResponse(
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
@@ -293,14 +289,13 @@ async def success_callback(session_id: str, request: Request):
# Callback endpoint for cancelled Stripe payments - updates billing session status
@billing_router.get('/cancel')
async def cancel_callback(session_id: str, request: Request):
async with a_session_maker() as session:
result = await session.execute(
select(BillingSession).where(
BillingSession.id == session_id,
BillingSession.status == 'in_progress',
)
with session_maker() as session:
billing_session = (
session.query(BillingSession)
.filter(BillingSession.id == session_id)
.filter(BillingSession.status == 'in_progress')
.first()
)
billing_session = result.scalar_one_or_none()
if billing_session:
logger.info(
'stripe_checkout_cancel',
@@ -312,7 +307,7 @@ async def cancel_callback(session_id: str, request: Request):
billing_session.status = 'cancelled'
billing_session.updated_at = datetime.now(UTC)
session.merge(billing_session)
await session.commit()
session.commit()
return RedirectResponse(
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy.future import select
from storage.database import a_session_maker
from storage.database import session_maker
from storage.feedback import ConversationFeedback
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
@@ -11,6 +11,7 @@ from openhands.events.event_store import EventStore
from openhands.server.dependencies import get_dependencies
from openhands.server.shared import file_store
from openhands.server.user_auth import get_user_id
from openhands.utils.async_utils import call_sync_from_async
# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint
# is protected. The actual protection is provided by SetAuthCookieMiddleware
@@ -36,19 +37,23 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
"""
# Verify the conversation belongs to the user
async with a_session_maker() as session:
result = await session.execute(
select(StoredConversationMetadataSaas).where(
StoredConversationMetadataSaas.conversation_id == conversation_id,
StoredConversationMetadataSaas.user_id == user_id,
)
)
metadata = result.scalars().first()
if not metadata:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f'Conversation {conversation_id} not found',
def _verify_conversation():
with session_maker() as session:
metadata = (
session.query(StoredConversationMetadataSaas)
.filter(
StoredConversationMetadataSaas.conversation_id == conversation_id,
StoredConversationMetadataSaas.user_id == user_id,
)
.first()
)
if not metadata:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f'Conversation {conversation_id} not found',
)
await call_sync_from_async(_verify_conversation)
# Create an event store to access the events directly
# This works even when the conversation is not running
@@ -98,9 +103,12 @@ async def submit_conversation_feedback(feedback: FeedbackRequest):
)
# Add to database
async with a_session_maker() as session:
session.add(new_feedback)
await session.commit()
def _save_feedback():
with session_maker() as session:
session.add(new_feedback)
session.commit()
await call_sync_from_async(_save_feedback)
return {'status': 'success', 'message': 'Feedback submitted successfully'}
@@ -119,27 +127,30 @@ async def get_batch_feedback(conversation_id: str, user_id: str = Depends(get_us
return {}
# Query for existing feedback for all events
async with a_session_maker() as session:
result = await session.execute(
select(ConversationFeedback).where(
ConversationFeedback.conversation_id == conversation_id,
ConversationFeedback.event_id.in_(event_ids),
def _check_feedback():
with session_maker() as session:
result = session.execute(
select(ConversationFeedback).where(
ConversationFeedback.conversation_id == conversation_id,
ConversationFeedback.event_id.in_(event_ids),
)
)
)
# Create a mapping of event_id to feedback
feedback_map = {
feedback.event_id: {
'exists': True,
'rating': feedback.rating,
'reason': feedback.reason,
# Create a mapping of event_id to feedback
feedback_map = {
feedback.event_id: {
'exists': True,
'rating': feedback.rating,
'reason': feedback.reason,
}
for feedback in result.scalars()
}
for feedback in result.scalars()
}
# Build response including all events
response = {}
for event_id in event_ids:
response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
# Build response including all events
response = {}
for event_id in event_ids:
response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
return response
return response
return await call_sync_from_async(_check_feedback)

View File

@@ -308,11 +308,10 @@ async def jira_events(
logger.info(f'Processing new Jira webhook event: {signature}')
redis_client.setex(key, 300, '1')
# Process the webhook in background after returning response.
# Note: For async functions, BackgroundTasks runs them in the same event loop
# (not a thread pool), so asyncpg connections work correctly.
# Process the webhook
message_payload = {'payload': payload}
message = Message(source=SourceType.JIRA, message=message_payload)
background_tasks.add_task(jira_manager.receive_message, message)
return JSONResponse({'success': True})

View File

@@ -7,6 +7,7 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from storage.api_key_store import ApiKeyStore
from storage.database import session_maker
from storage.device_code_store import DeviceCodeStore
from openhands.core.logger import openhands_logger as logger
@@ -53,7 +54,7 @@ class DeviceTokenErrorResponse(BaseModel):
# ---------------------------------------------------------------------------
oauth_device_router = APIRouter(prefix='/oauth/device')
device_code_store = DeviceCodeStore()
device_code_store = DeviceCodeStore(session_maker)
# ---------------------------------------------------------------------------
@@ -89,7 +90,7 @@ async def device_authorization(
) -> DeviceAuthorizationResponse:
"""Start device flow by generating device and user codes."""
try:
device_code_entry = await device_code_store.create_device_code(
device_code_entry = device_code_store.create_device_code(
expires_in=DEVICE_CODE_EXPIRES_IN,
)
@@ -124,7 +125,7 @@ async def device_authorization(
async def device_token(device_code: str = Form(...)):
"""Poll for a token until the user authorizes or the code expires."""
try:
device_code_entry = await device_code_store.get_by_device_code(device_code)
device_code_entry = device_code_store.get_by_device_code(device_code)
if not device_code_entry:
return _oauth_error(
@@ -137,9 +138,7 @@ async def device_token(device_code: str = Form(...)):
is_too_fast, current_interval = device_code_entry.check_rate_limit()
if is_too_fast:
# Update poll time and increase interval
await device_code_store.update_poll_time(
device_code, increase_interval=True
)
device_code_store.update_poll_time(device_code, increase_interval=True)
logger.warning(
'Client polling too fast, returning slow_down error',
extra={
@@ -155,7 +154,7 @@ async def device_token(device_code: str = Form(...)):
)
# Update poll time for successful rate limit check
await device_code_store.update_poll_time(device_code, increase_interval=False)
device_code_store.update_poll_time(device_code, increase_interval=False)
if device_code_entry.is_expired():
return _oauth_error(
@@ -182,7 +181,7 @@ async def device_token(device_code: str = Form(...)):
# Retrieve the specific API key for this device using the user_code
api_key_store = ApiKeyStore.get_instance()
device_key_name = f'{API_KEY_NAME} ({device_code_entry.user_code})'
device_api_key = await api_key_store.retrieve_api_key_by_name(
device_api_key = api_key_store.retrieve_api_key_by_name(
device_code_entry.keycloak_user_id, device_key_name
)
@@ -239,7 +238,7 @@ async def device_verification_authenticated(
)
# Validate device code
device_code_entry = await device_code_store.get_by_user_code(user_code)
device_code_entry = device_code_store.get_by_user_code(user_code)
if not device_code_entry:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -253,7 +252,7 @@ async def device_verification_authenticated(
)
# First, authorize the device code
success = await device_code_store.authorize_device_code(
success = device_code_store.authorize_device_code(
user_code=user_code,
user_id=user_id,
)
@@ -290,7 +289,7 @@ async def device_verification_authenticated(
# Clean up: revert the device authorization since API key creation failed
# This prevents the device from being in an authorized state without an API key
try:
await device_code_store.deny_device_code(user_code)
device_code_store.deny_device_code(user_code)
logger.info(
'Reverted device authorization due to API key creation failure',
extra={'user_code': user_code, 'user_id': user_id},

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, HTTPException, status
from sqlalchemy.sql import text
from storage.database import a_session_maker
from storage.database import session_maker
from storage.redis import create_redis_client
from openhands.core.logger import openhands_logger as logger
@@ -9,11 +9,11 @@ readiness_router = APIRouter()
@readiness_router.get('/ready')
async def is_ready():
def is_ready():
# Check database connection
try:
async with a_session_maker() as session:
await session.execute(text('SELECT 1'))
with session_maker() as session:
session.execute(text('SELECT 1'))
except Exception as e:
logger.error(f'Database check failed: {str(e)}')
raise HTTPException(

View File

@@ -388,4 +388,5 @@ async def _check_idp(
access_token.get_secret_value(), ProviderType(idp)
):
return default_value
return None

View File

@@ -4,14 +4,13 @@ import pickle
from datetime import datetime
from server.logger import logger
from sqlalchemy import and_, select
from storage.conversation_callback import (
CallbackStatus,
ConversationCallback,
ConversationCallbackProcessor,
)
from storage.conversation_work import ConversationWork
from storage.database import a_session_maker, session_maker
from storage.database import session_maker
from storage.stored_conversation_metadata import StoredConversationMetadata
from openhands.core.config import load_openhands_config
@@ -80,16 +79,15 @@ async def invoke_conversation_callbacks(
conversation_id: The conversation ID to process callbacks for
observation: The AgentStateChangedObservation that triggered the callback
"""
async with a_session_maker() as session:
result = await session.execute(
select(ConversationCallback).filter(
and_(
ConversationCallback.conversation_id == conversation_id,
ConversationCallback.status == CallbackStatus.ACTIVE,
)
with session_maker() as session:
callbacks = (
session.query(ConversationCallback)
.filter(
ConversationCallback.conversation_id == conversation_id,
ConversationCallback.status == CallbackStatus.ACTIVE,
)
.all()
)
callbacks = result.scalars().all()
for callback in callbacks:
try:
@@ -117,7 +115,7 @@ async def invoke_conversation_callbacks(
callback.status = CallbackStatus.ERROR
callback.updated_at = datetime.now()
await session.commit()
session.commit()
def update_conversation_metadata(conversation_id: str, content: dict):

View File

@@ -2,10 +2,6 @@
from dataclasses import dataclass
from server.verified_models.verified_model_models import (
VerifiedModel,
VerifiedModelPage,
)
from sqlalchemy import (
Boolean,
Column,
@@ -22,6 +18,10 @@ from sqlalchemy import (
from sqlalchemy.ext.asyncio import AsyncSession
from storage.base import Base
from enterprise.server.verified_models.verified_model_models import (
VerifiedModel,
VerifiedModelPage,
)
from openhands.app_server.config import depends_db_session
from openhands.core.logger import openhands_logger as logger

View File

@@ -5,16 +5,20 @@ import string
from dataclasses import dataclass
from datetime import UTC, datetime
from sqlalchemy import select, update
from sqlalchemy import update
from sqlalchemy.orm import sessionmaker
from storage.api_key import ApiKey
from storage.database import a_session_maker
from storage.database import session_maker
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
from openhands.utils.async_utils import call_sync_from_async
@dataclass
class ApiKeyStore:
session_maker: sessionmaker
API_KEY_PREFIX = 'sk-oh-'
def generate_api_key(self, length: int = 32) -> str:
@@ -39,8 +43,22 @@ class ApiKeyStore:
api_key = self.generate_api_key()
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id
await call_sync_from_async(
self._store_api_key, user_id, org_id, api_key, name, expires_at
)
async with a_session_maker() as session:
return api_key
def _store_api_key(
self,
user_id: str,
org_id: str,
api_key: str,
name: str | None,
expires_at: datetime | None = None,
) -> None:
"""Store an existing API key in the database."""
with self.session_maker() as session:
key_record = ApiKey(
key=api_key,
user_id=user_id,
@@ -49,17 +67,14 @@ class ApiKeyStore:
expires_at=expires_at,
)
session.add(key_record)
await session.commit()
session.commit()
return api_key
async def validate_api_key(self, api_key: str) -> str | None:
def validate_api_key(self, api_key: str) -> str | None:
"""Validate an API key and return the associated user_id if valid."""
now = datetime.now(UTC)
async with a_session_maker() as session:
result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key))
key_record = result.scalars().first()
with self.session_maker() as session:
key_record = session.query(ApiKey).filter(ApiKey.key == api_key).first()
if not key_record:
return None
@@ -76,40 +91,38 @@ class ApiKeyStore:
return None
# Update last_used_at timestamp
await session.execute(
session.execute(
update(ApiKey)
.where(ApiKey.id == key_record.id)
.values(last_used_at=now)
)
await session.commit()
session.commit()
return key_record.user_id
async def delete_api_key(self, api_key: str) -> bool:
def delete_api_key(self, api_key: str) -> bool:
"""Delete an API key by the key value."""
async with a_session_maker() as session:
result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key))
key_record = result.scalars().first()
with self.session_maker() as session:
key_record = session.query(ApiKey).filter(ApiKey.key == api_key).first()
if not key_record:
return False
await session.delete(key_record)
await session.commit()
session.delete(key_record)
session.commit()
return True
async def delete_api_key_by_id(self, key_id: int) -> bool:
def delete_api_key_by_id(self, key_id: int) -> bool:
"""Delete an API key by its ID."""
async with a_session_maker() as session:
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
key_record = result.scalars().first()
with self.session_maker() as session:
key_record = session.query(ApiKey).filter(ApiKey.id == key_id).first()
if not key_record:
return False
await session.delete(key_record)
await session.commit()
session.delete(key_record)
session.commit()
return True
@@ -117,55 +130,64 @@ class ApiKeyStore:
"""List all API keys for a user."""
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id
return await call_sync_from_async(self._list_api_keys_from_db, user_id, org_id)
async with a_session_maker() as session:
result = await session.execute(
select(ApiKey).filter(
ApiKey.user_id == user_id, ApiKey.org_id == org_id
)
def _list_api_keys_from_db(self, user_id: str, org_id: str) -> list[ApiKey]:
with self.session_maker() as session:
keys: list[ApiKey] = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id)
.filter(ApiKey.org_id == org_id)
.all()
)
keys = result.scalars().all()
return [key for key in keys if key.name != 'MCP_API_KEY']
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id
return await call_sync_from_async(
self._retrieve_mcp_api_key_from_db, user_id, org_id
)
async with a_session_maker() as session:
result = await session.execute(
select(ApiKey).filter(
ApiKey.user_id == user_id, ApiKey.org_id == org_id
)
def _retrieve_mcp_api_key_from_db(self, user_id: str, org_id: str) -> str | None:
with self.session_maker() as session:
keys: list[ApiKey] = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id)
.filter(ApiKey.org_id == org_id)
.all()
)
keys = result.scalars().all()
for key in keys:
if key.name == 'MCP_API_KEY':
return key.key
return None
async def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None:
def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None:
"""Retrieve an API key by name for a specific user."""
async with a_session_maker() as session:
result = await session.execute(
select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name)
with self.session_maker() as session:
key_record = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
.first()
)
key_record = result.scalars().first()
return key_record.key if key_record else None
async def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
"""Delete an API key by name for a specific user."""
async with a_session_maker() as session:
result = await session.execute(
select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name)
with self.session_maker() as session:
key_record = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
.first()
)
key_record = result.scalars().first()
if not key_record:
return False
await session.delete(key_record)
await session.commit()
session.delete(key_record)
session.commit()
return True
@@ -173,4 +195,4 @@ class ApiKeyStore:
def get_instance(cls) -> ApiKeyStore:
"""Get an instance of the ApiKeyStore."""
logger.debug('api_key_store.get_instance')
return ApiKeyStore()
return ApiKeyStore(session_maker)

View File

@@ -7,6 +7,7 @@ from typing import Awaitable, Callable, Dict
from server.auth.auth_error import TokenRefreshError
from sqlalchemy import select, text, update
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from storage.auth_tokens import AuthTokens
from storage.database import a_session_maker
@@ -26,6 +27,7 @@ LOCK_TIMEOUT_SECONDS = 5
class AuthTokenStore:
keycloak_user_id: str
idp: ProviderType
a_session_maker: sessionmaker
@property
def identity_provider_value(self) -> str:
@@ -71,7 +73,7 @@ class AuthTokenStore:
access_token_expires_at: Expiration time for access token (seconds since epoch)
refresh_token_expires_at: Expiration time for refresh token (seconds since epoch)
"""
async with a_session_maker() as session:
async with self.a_session_maker() as session:
async with session.begin(): # Explicitly start a transaction
result = await session.execute(
select(AuthTokens).where(
@@ -136,7 +138,7 @@ class AuthTokenStore:
a 401 response to prompt the user to re-authenticate.
"""
# FAST PATH: Check without lock first to avoid unnecessary lock contention
async with a_session_maker() as session:
async with self.a_session_maker() as session:
result = await session.execute(
select(AuthTokens).filter(
AuthTokens.keycloak_user_id == self.keycloak_user_id,
@@ -165,7 +167,7 @@ class AuthTokenStore:
# SLOW PATH: Token needs refresh, acquire lock
try:
async with a_session_maker() as session:
async with self.a_session_maker() as session:
async with session.begin():
# Set a lock timeout to prevent indefinite blocking
# This ensures we don't hold connections forever if something goes wrong
@@ -298,4 +300,6 @@ class AuthTokenStore:
logger.debug(f'auth_token_store.get_instance::{keycloak_user_id}')
if keycloak_user_id:
keycloak_user_id = str(keycloak_user_id)
return AuthTokenStore(keycloak_user_id=keycloak_user_id, idp=idp)
return AuthTokenStore(
keycloak_user_id=keycloak_user_id, idp=idp, a_session_maker=a_session_maker
)

View File

@@ -1,12 +1,14 @@
from dataclasses import dataclass
from sqlalchemy import text
from storage.database import a_session_maker
from sqlalchemy.orm import sessionmaker
@dataclass
class BlockedEmailDomainStore:
async def is_domain_blocked(self, domain: str) -> bool:
session_maker: sessionmaker
def is_domain_blocked(self, domain: str) -> bool:
"""Check if a domain is blocked by querying the database directly.
This method uses SQL to efficiently check if the domain matches any blocked pattern:
@@ -19,9 +21,9 @@ class BlockedEmailDomainStore:
Returns:
True if the domain is blocked, False otherwise
"""
async with a_session_maker() as session:
with self.session_maker() as session:
# SQL query that handles both TLD patterns and full domain patterns
# TLD patterns (starting with '.'): check if domain ends with it (case-insensitive)
# TLD patterns (starting with '.'): check if domain ends with the pattern
# Full domain patterns: check for exact match or subdomain match
# All comparisons are case-insensitive using LOWER() to ensure consistent matching
query = text("""
@@ -39,5 +41,5 @@ class BlockedEmailDomainStore:
))
)
""")
result = await session.execute(query, {'domain': domain})
return bool(result.scalar())
result = session.execute(query, {'domain': domain}).scalar()
return bool(result)

View File

@@ -47,11 +47,7 @@ class DeviceCode(Base):
def is_expired(self) -> bool:
"""Check if the device code has expired."""
now = datetime.now(timezone.utc)
# Handle timezone-naive datetime from database by assuming it's UTC
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return now > expires_at
return now > self.expires_at
def is_pending(self) -> bool:
"""Check if the device code is still pending authorization."""
@@ -89,13 +85,8 @@ class DeviceCode(Base):
if self.last_poll_time is None:
return False, self.current_interval
# Handle timezone-naive datetime from database by assuming it's UTC
last_poll_time = self.last_poll_time
if last_poll_time.tzinfo is None:
last_poll_time = last_poll_time.replace(tzinfo=timezone.utc)
# Calculate time since last poll
time_since_last_poll = (now - last_poll_time).total_seconds()
time_since_last_poll = (now - self.last_poll_time).total_seconds()
# Check if polling too fast
if time_since_last_poll < self.current_interval:

View File

@@ -1,20 +1,19 @@
"""Device code store for OAuth 2.0 Device Flow."""
from __future__ import annotations
import secrets
import string
from datetime import datetime, timedelta, timezone
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from storage.database import a_session_maker
from storage.device_code import DeviceCode
class DeviceCodeStore:
"""Store for managing OAuth 2.0 device codes."""
def __init__(self, session_maker):
self.session_maker = session_maker
def generate_user_code(self) -> str:
"""Generate a human-readable user code (8 characters, uppercase letters and digits)."""
# Use a mix of uppercase letters and digits, avoiding confusing characters
@@ -26,7 +25,7 @@ class DeviceCodeStore:
alphabet = string.ascii_letters + string.digits
return ''.join(secrets.choice(alphabet) for _ in range(128))
async def create_device_code(
def create_device_code(
self,
expires_in: int = 600, # 10 minutes default
max_attempts: int = 10,
@@ -59,10 +58,11 @@ class DeviceCodeStore:
)
try:
async with a_session_maker() as session:
with self.session_maker() as session:
session.add(device_code_entry)
await session.commit()
await session.refresh(device_code_entry)
session.commit()
session.refresh(device_code_entry)
session.expunge(device_code_entry) # Detach from session cleanly
return device_code_entry
except IntegrityError:
# Constraint violation - codes already exist, retry with new codes
@@ -72,23 +72,25 @@ class DeviceCodeStore:
f'Failed to generate unique device codes after {max_attempts} attempts'
)
async def get_by_device_code(self, device_code: str) -> DeviceCode | None:
def get_by_device_code(self, device_code: str) -> DeviceCode | None:
"""Get device code entry by device code."""
async with a_session_maker() as session:
result = await session.execute(
select(DeviceCode).filter_by(device_code=device_code)
with self.session_maker() as session:
result = (
session.query(DeviceCode).filter_by(device_code=device_code).first()
)
return result.scalars().first()
if result:
session.expunge(result) # Detach from session cleanly
return result
async def get_by_user_code(self, user_code: str) -> DeviceCode | None:
def get_by_user_code(self, user_code: str) -> DeviceCode | None:
"""Get device code entry by user code."""
async with a_session_maker() as session:
result = await session.execute(
select(DeviceCode).filter_by(user_code=user_code)
)
return result.scalars().first()
with self.session_maker() as session:
result = session.query(DeviceCode).filter_by(user_code=user_code).first()
if result:
session.expunge(result) # Detach from session cleanly
return result
async def authorize_device_code(self, user_code: str, user_id: str) -> bool:
def authorize_device_code(self, user_code: str, user_id: str) -> bool:
"""Authorize a device code.
Args:
@@ -98,11 +100,10 @@ class DeviceCodeStore:
Returns:
True if authorization was successful, False otherwise
"""
async with a_session_maker() as session:
result = await session.execute(
select(DeviceCode).filter_by(user_code=user_code)
with self.session_maker() as session:
device_code_entry = (
session.query(DeviceCode).filter_by(user_code=user_code).first()
)
device_code_entry = result.scalars().first()
if not device_code_entry:
return False
@@ -111,11 +112,11 @@ class DeviceCodeStore:
return False
device_code_entry.authorize(user_id)
await session.commit()
session.commit()
return True
async def deny_device_code(self, user_code: str) -> bool:
def deny_device_code(self, user_code: str) -> bool:
"""Deny a device code authorization.
Args:
@@ -124,11 +125,10 @@ class DeviceCodeStore:
Returns:
True if denial was successful, False otherwise
"""
async with a_session_maker() as session:
result = await session.execute(
select(DeviceCode).filter_by(user_code=user_code)
with self.session_maker() as session:
device_code_entry = (
session.query(DeviceCode).filter_by(user_code=user_code).first()
)
device_code_entry = result.scalars().first()
if not device_code_entry:
return False
@@ -137,11 +137,11 @@ class DeviceCodeStore:
return False
device_code_entry.deny()
await session.commit()
session.commit()
return True
async def update_poll_time(
def update_poll_time(
self, device_code: str, increase_interval: bool = False
) -> bool:
"""Update the poll time for a device code and optionally increase interval.
@@ -153,16 +153,15 @@ class DeviceCodeStore:
Returns:
True if update was successful, False otherwise
"""
async with a_session_maker() as session:
result = await session.execute(
select(DeviceCode).filter_by(device_code=device_code)
with self.session_maker() as session:
device_code_entry = (
session.query(DeviceCode).filter_by(device_code=device_code).first()
)
device_code_entry = result.scalars().first()
if not device_code_entry:
return False
device_code_entry.update_poll_time(increase_interval)
await session.commit()
session.commit()
return True

View File

@@ -5,6 +5,7 @@ from dataclasses import dataclass
from integrations.types import GitLabResourceType
from sqlalchemy import and_, asc, select, text, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import sessionmaker
from storage.database import a_session_maker
from storage.gitlab_webhook import GitlabWebhook
@@ -13,6 +14,8 @@ from openhands.core.logger import openhands_logger as logger
@dataclass
class GitlabWebhookStore:
a_session_maker: sessionmaker = a_session_maker
@staticmethod
def determine_resource_type(
webhook: GitlabWebhook,
@@ -41,7 +44,7 @@ class GitlabWebhookStore:
if not project_details:
return
async with a_session_maker() as session:
async with self.a_session_maker() as session:
async with session.begin():
# Convert GitlabWebhook objects to dictionaries for the insert
# Using __dict__ and filtering out SQLAlchemy internal attributes and 'id'
@@ -85,7 +88,7 @@ class GitlabWebhookStore:
"""
resource_type, resource_id = GitlabWebhookStore.determine_resource_type(webhook)
async with a_session_maker() as session:
async with self.a_session_maker() as session:
async with session.begin():
stmt = (
update(GitlabWebhook).where(GitlabWebhook.project_id == resource_id)
@@ -119,7 +122,7 @@ class GitlabWebhookStore:
},
)
async with a_session_maker() as session:
async with self.a_session_maker() as session:
async with session.begin():
# Create query based on the identifier provided
if resource_type == GitLabResourceType.PROJECT:
@@ -182,7 +185,7 @@ class GitlabWebhookStore:
List of GitlabWebhook objects that need processing
"""
async with a_session_maker() as session:
async with self.a_session_maker() as session:
query = (
select(GitlabWebhook)
.where(GitlabWebhook.webhook_exists.is_(False))
@@ -198,7 +201,7 @@ class GitlabWebhookStore:
"""
Get's webhook secret given the webhook uuid and admin keycloak user id
"""
async with a_session_maker() as session:
async with self.a_session_maker() as session:
query = (
select(GitlabWebhook)
.where(
@@ -232,7 +235,7 @@ class GitlabWebhookStore:
Returns:
GitlabWebhook object if found, None otherwise
"""
async with a_session_maker() as session:
async with self.a_session_maker() as session:
if resource_type == GitLabResourceType.PROJECT:
query = select(GitlabWebhook).where(
GitlabWebhook.project_id == resource_id
@@ -260,7 +263,7 @@ class GitlabWebhookStore:
Returns:
Tuple of (project_webhook_map, group_webhook_map)
"""
async with a_session_maker() as session:
async with self.a_session_maker() as session:
project_webhook_map = {}
group_webhook_map = {}
@@ -300,7 +303,7 @@ class GitlabWebhookStore:
Returns:
True if webhook was reset, False if not found
"""
async with a_session_maker() as session:
async with self.a_session_maker() as session:
async with session.begin():
if resource_type == GitLabResourceType.PROJECT:
update_statement = (
@@ -345,4 +348,4 @@ class GitlabWebhookStore:
Returns:
An instance of GitlabWebhookStore
"""
return GitlabWebhookStore()
return GitlabWebhookStore(a_session_maker)

View File

@@ -3,8 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from sqlalchemy import select
from storage.database import a_session_maker
from storage.database import session_maker
from storage.jira_dc_conversation import JiraDcConversation
from storage.jira_dc_user import JiraDcUser
from storage.jira_dc_workspace import JiraDcWorkspace
@@ -25,7 +24,7 @@ class JiraDcIntegrationStore:
) -> JiraDcWorkspace:
"""Create a new Jira DC workspace with encrypted sensitive data."""
async with a_session_maker() as session:
with session_maker() as session:
workspace = JiraDcWorkspace(
name=name.lower(),
admin_user_id=admin_user_id,
@@ -35,8 +34,8 @@ class JiraDcIntegrationStore:
status=status,
)
session.add(workspace)
await session.commit()
await session.refresh(workspace)
session.commit()
session.refresh(workspace)
logger.info(f'[Jira DC] Created workspace {workspace.name}')
return workspace
@@ -49,12 +48,11 @@ class JiraDcIntegrationStore:
status: Optional[str] = None,
) -> JiraDcWorkspace:
"""Update an existing Jira DC workspace with encrypted sensitive data."""
async with a_session_maker() as session:
with session_maker() as session:
# Find existing workspace by ID
result = await session.execute(
select(JiraDcWorkspace).where(JiraDcWorkspace.id == id)
workspace = (
session.query(JiraDcWorkspace).filter(JiraDcWorkspace.id == id).first()
)
workspace = result.scalar_one_or_none()
if not workspace:
raise ValueError(f'Workspace with ID "{id}" not found')
@@ -71,8 +69,8 @@ class JiraDcIntegrationStore:
if status is not None:
workspace.status = status
await session.commit()
await session.refresh(workspace)
session.commit()
session.refresh(workspace)
logger.info(f'[Jira DC] Updated workspace {workspace.name}')
return workspace
@@ -93,10 +91,10 @@ class JiraDcIntegrationStore:
status=status,
)
async with a_session_maker() as session:
with session_maker() as session:
session.add(jira_dc_user)
await session.commit()
await session.refresh(jira_dc_user)
session.commit()
session.refresh(jira_dc_user)
logger.info(
f'[Jira DC] Created user {jira_dc_user.id} for workspace {jira_dc_workspace_id}'
@@ -105,91 +103,94 @@ class JiraDcIntegrationStore:
async def get_workspace_by_id(self, workspace_id: int) -> Optional[JiraDcWorkspace]:
"""Retrieve workspace by ID."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcWorkspace).where(JiraDcWorkspace.id == workspace_id)
with session_maker() as session:
return (
session.query(JiraDcWorkspace)
.filter(JiraDcWorkspace.id == workspace_id)
.first()
)
return result.scalar_one_or_none()
async def get_workspace_by_name(
self, workspace_name: str
) -> Optional[JiraDcWorkspace]:
"""Retrieve workspace by name."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcWorkspace).where(
JiraDcWorkspace.name == workspace_name.lower()
)
with session_maker() as session:
return (
session.query(JiraDcWorkspace)
.filter(JiraDcWorkspace.name == workspace_name.lower())
.first()
)
return result.scalar_one_or_none()
async def get_user_by_active_workspace(
self, keycloak_user_id: str
) -> Optional[JiraDcUser]:
"""Retrieve user by Keycloak user ID."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcUser).where(
with session_maker() as session:
return (
session.query(JiraDcUser)
.filter(
JiraDcUser.keycloak_user_id == keycloak_user_id,
JiraDcUser.status == 'active',
)
.first()
)
return result.scalar_one_or_none()
async def get_user_by_keycloak_id_and_workspace(
self, keycloak_user_id: str, jira_dc_workspace_id: int
) -> Optional[JiraDcUser]:
"""Get Jira DC user by Keycloak user ID and workspace ID."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcUser).where(
with session_maker() as session:
return (
session.query(JiraDcUser)
.filter(
JiraDcUser.keycloak_user_id == keycloak_user_id,
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
)
.first()
)
return result.scalar_one_or_none()
async def get_active_user(
self, jira_dc_user_id: str, jira_dc_workspace_id: int
) -> Optional[JiraDcUser]:
"""Get Jira DC user by Keycloak user ID and workspace ID."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcUser).where(
with session_maker() as session:
return (
session.query(JiraDcUser)
.filter(
JiraDcUser.jira_dc_user_id == jira_dc_user_id,
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
JiraDcUser.status == 'active',
)
.first()
)
return result.scalar_one_or_none()
async def get_active_user_by_keycloak_id_and_workspace(
self, keycloak_user_id: str, jira_dc_workspace_id: int
) -> Optional[JiraDcUser]:
"""Get Jira DC user by Keycloak user ID and workspace ID."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcUser).where(
with session_maker() as session:
return (
session.query(JiraDcUser)
.filter(
JiraDcUser.keycloak_user_id == keycloak_user_id,
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
JiraDcUser.status == 'active',
)
.first()
)
return result.scalar_one_or_none()
async def update_user_integration_status(
self, keycloak_user_id: str, status: str
) -> JiraDcUser:
"""Update the status of a Jira DC user mapping."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcUser).where(
JiraDcUser.keycloak_user_id == keycloak_user_id
)
with session_maker() as session:
user = (
session.query(JiraDcUser)
.filter(JiraDcUser.keycloak_user_id == keycloak_user_id)
.first()
)
user = result.scalar_one_or_none()
if not user:
raise ValueError(
@@ -197,35 +198,37 @@ class JiraDcIntegrationStore:
)
user.status = status
await session.commit()
await session.refresh(user)
session.commit()
session.refresh(user)
logger.info(f'[Jira DC] Updated user {keycloak_user_id} status to {status}')
return user
async def deactivate_workspace(self, workspace_id: int):
"""Deactivate the workspace and all user links for a given workspace."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcUser).where(
with session_maker() as session:
users = (
session.query(JiraDcUser)
.filter(
JiraDcUser.jira_dc_workspace_id == workspace_id,
JiraDcUser.status == 'active',
)
.all()
)
users = result.scalars().all()
for user in users:
user.status = 'inactive'
session.add(user)
result = await session.execute(
select(JiraDcWorkspace).where(JiraDcWorkspace.id == workspace_id)
workspace = (
session.query(JiraDcWorkspace)
.filter(JiraDcWorkspace.id == workspace_id)
.first()
)
workspace = result.scalar_one_or_none()
if workspace:
workspace.status = 'inactive'
session.add(workspace)
await session.commit()
session.commit()
logger.info(
f'[Jira DC] Deactivated all user links for workspace {workspace_id}'
@@ -235,22 +238,23 @@ class JiraDcIntegrationStore:
self, jira_dc_conversation: JiraDcConversation
) -> None:
"""Create a new Jira DC conversation record."""
async with a_session_maker() as session:
with session_maker() as session:
session.add(jira_dc_conversation)
await session.commit()
session.commit()
async def get_user_conversations_by_issue_id(
self, issue_id: str, jira_dc_user_id: int
) -> JiraDcConversation | None:
"""Get a Jira DC conversation by issue ID and jira dc user ID."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraDcConversation).where(
with session_maker() as session:
return (
session.query(JiraDcConversation)
.filter(
JiraDcConversation.issue_id == issue_id,
JiraDcConversation.jira_dc_user_id == jira_dc_user_id,
)
.first()
)
return result.scalar_one_or_none()
@classmethod
def get_instance(cls) -> JiraDcIntegrationStore:

View File

@@ -3,8 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from sqlalchemy import and_, select
from storage.database import a_session_maker
from storage.database import session_maker
from storage.jira_conversation import JiraConversation
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
@@ -36,10 +35,10 @@ class JiraIntegrationStore:
status=status,
)
async with a_session_maker() as session:
with session_maker() as session:
session.add(workspace)
await session.commit()
await session.refresh(workspace)
session.commit()
session.refresh(workspace)
logger.info(f'[Jira] Created workspace {workspace.name}')
return workspace
@@ -54,12 +53,11 @@ class JiraIntegrationStore:
status: Optional[str] = None,
) -> JiraWorkspace:
"""Update an existing Jira workspace with encrypted sensitive data."""
async with a_session_maker() as session:
with session_maker() as session:
# Find existing workspace by ID
result = await session.execute(
select(JiraWorkspace).filter(JiraWorkspace.id == id)
workspace = (
session.query(JiraWorkspace).filter(JiraWorkspace.id == id).first()
)
workspace = result.scalars().first()
if not workspace:
raise ValueError(f'Workspace with ID "{id}" not found')
@@ -79,11 +77,11 @@ class JiraIntegrationStore:
if status is not None:
workspace.status = status
await session.commit()
await session.refresh(workspace)
session.commit()
session.refresh(workspace)
logger.info(f'[Jira] Updated workspace {workspace.name}')
return workspace
logger.info(f'[Jira] Updated workspace {workspace.name}')
return workspace
async def create_workspace_link(
self,
@@ -101,10 +99,10 @@ class JiraIntegrationStore:
status=status,
)
async with a_session_maker() as session:
with session_maker() as session:
session.add(jira_user)
await session.commit()
await session.refresh(jira_user)
session.commit()
session.refresh(jira_user)
logger.info(
f'[Jira] Created user {jira_user.id} for workspace {jira_workspace_id}'
@@ -113,77 +111,75 @@ class JiraIntegrationStore:
async def get_workspace_by_id(self, workspace_id: int) -> Optional[JiraWorkspace]:
"""Retrieve workspace by ID."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraWorkspace).filter(JiraWorkspace.id == workspace_id)
with session_maker() as session:
return (
session.query(JiraWorkspace)
.filter(JiraWorkspace.id == workspace_id)
.first()
)
return result.scalars().first()
async def get_workspace_by_name(self, workspace_name: str) -> JiraWorkspace | None:
"""Retrieve workspace by name."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraWorkspace).filter(
JiraWorkspace.name == workspace_name.lower()
)
with session_maker() as session:
return (
session.query(JiraWorkspace)
.filter(JiraWorkspace.name == workspace_name.lower())
.first()
)
return result.scalars().first()
async def get_user_by_active_workspace(
self, keycloak_user_id: str
) -> Optional[JiraUser]:
"""Get Jira user by Keycloak user ID."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraUser).filter(
and_(
JiraUser.keycloak_user_id == keycloak_user_id,
JiraUser.status == 'active',
)
with session_maker() as session:
return (
session.query(JiraUser)
.filter(
JiraUser.keycloak_user_id == keycloak_user_id,
JiraUser.status == 'active',
)
.first()
)
return result.scalars().first()
async def get_user_by_keycloak_id_and_workspace(
self, keycloak_user_id: str, jira_workspace_id: int
) -> Optional[JiraUser]:
"""Get Jira user by Keycloak user ID and workspace ID."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraUser).filter(
and_(
JiraUser.keycloak_user_id == keycloak_user_id,
JiraUser.jira_workspace_id == jira_workspace_id,
)
with session_maker() as session:
return (
session.query(JiraUser)
.filter(
JiraUser.keycloak_user_id == keycloak_user_id,
JiraUser.jira_workspace_id == jira_workspace_id,
)
.first()
)
return result.scalars().first()
async def get_active_user(
self, jira_user_id: str, jira_workspace_id: int
) -> Optional[JiraUser]:
"""Get Jira user by Keycloak user ID and workspace ID."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraUser).filter(
and_(
JiraUser.jira_user_id == jira_user_id,
JiraUser.jira_workspace_id == jira_workspace_id,
JiraUser.status == 'active',
)
with session_maker() as session:
return (
session.query(JiraUser)
.filter(
JiraUser.jira_user_id == jira_user_id,
JiraUser.jira_workspace_id == jira_workspace_id,
JiraUser.status == 'active',
)
.first()
)
return result.scalars().first()
async def update_user_integration_status(
self, keycloak_user_id: str, status: str
) -> JiraUser:
"""Update Jira user integration status."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraUser).filter(JiraUser.keycloak_user_id == keycloak_user_id)
with session_maker() as session:
jira_user = (
session.query(JiraUser)
.filter(JiraUser.keycloak_user_id == keycloak_user_id)
.first()
)
jira_user = result.scalars().first()
if not jira_user:
raise ValueError(
@@ -191,61 +187,60 @@ class JiraIntegrationStore:
)
jira_user.status = status
await session.commit()
await session.refresh(jira_user)
session.commit()
session.refresh(jira_user)
logger.info(f'[Jira] Updated user {keycloak_user_id} status to {status}')
return jira_user
async def deactivate_workspace(self, workspace_id: int):
"""Deactivate the workspace and all user links for a given workspace."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraUser).filter(
and_(
JiraUser.jira_workspace_id == workspace_id,
JiraUser.status == 'active',
)
with session_maker() as session:
users = (
session.query(JiraUser)
.filter(
JiraUser.jira_workspace_id == workspace_id,
JiraUser.status == 'active',
)
.all()
)
users = result.scalars().all()
for user in users:
user.status = 'inactive'
session.add(user)
result = await session.execute(
select(JiraWorkspace).filter(JiraWorkspace.id == workspace_id)
workspace = (
session.query(JiraWorkspace)
.filter(JiraWorkspace.id == workspace_id)
.first()
)
workspace = result.scalars().first()
if workspace:
workspace.status = 'inactive'
session.add(workspace)
await session.commit()
session.commit()
logger.info(f'[Jira] Deactivated all user links for workspace {workspace_id}')
async def create_conversation(self, jira_conversation: JiraConversation) -> None:
"""Create a new Jira conversation record."""
async with a_session_maker() as session:
with session_maker() as session:
session.add(jira_conversation)
await session.commit()
session.commit()
async def get_user_conversations_by_issue_id(
self, issue_id: str, jira_user_id: int
) -> JiraConversation | None:
"""Get a Jira conversation by issue ID and jira user ID."""
async with a_session_maker() as session:
result = await session.execute(
select(JiraConversation).filter(
and_(
JiraConversation.issue_id == issue_id,
JiraConversation.jira_user_id == jira_user_id,
)
with session_maker() as session:
return (
session.query(JiraConversation)
.filter(
JiraConversation.issue_id == issue_id,
JiraConversation.jira_user_id == jira_user_id,
)
.first()
)
return result.scalars().first()
@classmethod
def get_instance(cls) -> JiraIntegrationStore:

View File

@@ -3,8 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from sqlalchemy import select
from storage.database import a_session_maker
from storage.database import session_maker
from storage.linear_conversation import LinearConversation
from storage.linear_user import LinearUser
from storage.linear_workspace import LinearWorkspace
@@ -36,10 +35,10 @@ class LinearIntegrationStore:
status=status,
)
async with a_session_maker() as session:
with session_maker() as session:
session.add(workspace)
await session.commit()
await session.refresh(workspace)
session.commit()
session.refresh(workspace)
logger.info(f'[Linear] Created workspace {workspace.name}')
return workspace
@@ -54,12 +53,11 @@ class LinearIntegrationStore:
status: Optional[str] = None,
) -> LinearWorkspace:
"""Update an existing Linear workspace with encrypted sensitive data."""
async with a_session_maker() as session:
with session_maker() as session:
# Find existing workspace by ID
result = await session.execute(
select(LinearWorkspace).where(LinearWorkspace.id == id)
workspace = (
session.query(LinearWorkspace).filter(LinearWorkspace.id == id).first()
)
workspace = result.scalar_one_or_none()
if not workspace:
raise ValueError(f'Workspace with ID "{id}" not found')
@@ -79,8 +77,8 @@ class LinearIntegrationStore:
if status is not None:
workspace.status = status
await session.commit()
await session.refresh(workspace)
session.commit()
session.refresh(workspace)
logger.info(f'[Linear] Updated workspace {workspace.name}')
return workspace
@@ -100,10 +98,10 @@ class LinearIntegrationStore:
status=status,
)
async with a_session_maker() as session:
with session_maker() as session:
session.add(linear_user)
await session.commit()
await session.refresh(linear_user)
session.commit()
session.refresh(linear_user)
logger.info(
f'[Linear] Created user {linear_user.id} for workspace {linear_workspace_id}'
@@ -112,75 +110,77 @@ class LinearIntegrationStore:
async def get_workspace_by_id(self, workspace_id: int) -> Optional[LinearWorkspace]:
"""Retrieve workspace by ID."""
async with a_session_maker() as session:
result = await session.execute(
select(LinearWorkspace).where(LinearWorkspace.id == workspace_id)
with session_maker() as session:
return (
session.query(LinearWorkspace)
.filter(LinearWorkspace.id == workspace_id)
.first()
)
return result.scalar_one_or_none()
async def get_workspace_by_name(
self, workspace_name: str
) -> Optional[LinearWorkspace]:
"""Retrieve workspace by name."""
async with a_session_maker() as session:
result = await session.execute(
select(LinearWorkspace).where(
LinearWorkspace.name == workspace_name.lower()
)
with session_maker() as session:
return (
session.query(LinearWorkspace)
.filter(LinearWorkspace.name == workspace_name.lower())
.first()
)
return result.scalar_one_or_none()
async def get_user_by_active_workspace(
self, keycloak_user_id: str
) -> LinearUser | None:
"""Get Linear user by Keycloak user ID."""
async with a_session_maker() as session:
result = await session.execute(
select(LinearUser).where(
with session_maker() as session:
return (
session.query(LinearUser)
.filter(
LinearUser.keycloak_user_id == keycloak_user_id,
LinearUser.status == 'active',
)
.first()
)
return result.scalar_one_or_none()
async def get_user_by_keycloak_id_and_workspace(
self, keycloak_user_id: str, linear_workspace_id: int
) -> Optional[LinearUser]:
"""Get Linear user by Keycloak user ID and workspace ID."""
async with a_session_maker() as session:
result = await session.execute(
select(LinearUser).where(
with session_maker() as session:
return (
session.query(LinearUser)
.filter(
LinearUser.keycloak_user_id == keycloak_user_id,
LinearUser.linear_workspace_id == linear_workspace_id,
)
.first()
)
return result.scalar_one_or_none()
async def get_active_user(
self, linear_user_id: str, linear_workspace_id: int
) -> Optional[LinearUser]:
"""Get Linear user by Keycloak user ID and workspace ID."""
async with a_session_maker() as session:
result = await session.execute(
select(LinearUser).where(
with session_maker() as session:
return (
session.query(LinearUser)
.filter(
LinearUser.linear_user_id == linear_user_id,
LinearUser.linear_workspace_id == linear_workspace_id,
LinearUser.status == 'active',
)
.first()
)
return result.scalar_one_or_none()
async def update_user_integration_status(
self, keycloak_user_id: str, status: str
) -> LinearUser:
"""Update Linear user integration status."""
async with a_session_maker() as session:
result = await session.execute(
select(LinearUser).where(
LinearUser.keycloak_user_id == keycloak_user_id
)
with session_maker() as session:
linear_user = (
session.query(LinearUser)
.filter(LinearUser.keycloak_user_id == keycloak_user_id)
.first()
)
linear_user = result.scalar_one_or_none()
if not linear_user:
raise ValueError(
@@ -188,36 +188,38 @@ class LinearIntegrationStore:
)
linear_user.status = status
await session.commit()
await session.refresh(linear_user)
session.commit()
session.refresh(linear_user)
logger.info(f'[Linear] Updated user {keycloak_user_id} status to {status}')
return linear_user
async def deactivate_workspace(self, workspace_id: int):
"""Deactivate the workspace and all user links for a given workspace."""
async with a_session_maker() as session:
result = await session.execute(
select(LinearUser).where(
with session_maker() as session:
users = (
session.query(LinearUser)
.filter(
LinearUser.linear_workspace_id == workspace_id,
LinearUser.status == 'active',
)
.all()
)
users = result.scalars().all()
for user in users:
user.status = 'inactive'
session.add(user)
result = await session.execute(
select(LinearWorkspace).where(LinearWorkspace.id == workspace_id)
workspace = (
session.query(LinearWorkspace)
.filter(LinearWorkspace.id == workspace_id)
.first()
)
workspace = result.scalar_one_or_none()
if workspace:
workspace.status = 'inactive'
session.add(workspace)
await session.commit()
session.commit()
logger.info(f'[Jira] Deactivated all user links for workspace {workspace_id}')
@@ -225,22 +227,23 @@ class LinearIntegrationStore:
self, linear_conversation: LinearConversation
) -> None:
"""Create a new Linear conversation record."""
async with a_session_maker() as session:
with session_maker() as session:
session.add(linear_conversation)
await session.commit()
session.commit()
async def get_user_conversations_by_issue_id(
self, issue_id: str, linear_user_id: int
) -> LinearConversation | None:
"""Get a Linear conversation by issue ID and linear user ID."""
async with a_session_maker() as session:
result = await session.execute(
select(LinearConversation).where(
with session_maker() as session:
return (
session.query(LinearConversation)
.filter(
LinearConversation.issue_id == issue_id,
LinearConversation.linear_user_id == linear_user_id,
)
.first()
)
return result.scalar_one_or_none()
@classmethod
def get_instance(cls) -> LinearIntegrationStore:

View File

@@ -2,8 +2,8 @@ from __future__ import annotations
from dataclasses import dataclass
from sqlalchemy import select
from storage.database import a_session_maker
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from storage.stored_offline_token import StoredOfflineToken
from openhands.core.config.openhands_config import OpenHandsConfig
@@ -13,17 +13,17 @@ from openhands.core.logger import openhands_logger as logger
@dataclass
class OfflineTokenStore:
user_id: str
session_maker: sessionmaker
config: OpenHandsConfig
async def store_token(self, offline_token: str) -> None:
"""Store an offline token in the database."""
async with a_session_maker() as session:
result = await session.execute(
select(StoredOfflineToken).where(
StoredOfflineToken.user_id == self.user_id
)
with self.session_maker() as session:
token_record = (
session.query(StoredOfflineToken)
.filter(StoredOfflineToken.user_id == self.user_id)
.first()
)
token_record = result.scalar_one_or_none()
if token_record:
token_record.offline_token = offline_token
@@ -32,17 +32,16 @@ class OfflineTokenStore:
user_id=self.user_id, offline_token=offline_token
)
session.add(token_record)
await session.commit()
session.commit()
async def load_token(self) -> str | None:
"""Load an offline token from the database."""
async with a_session_maker() as session:
result = await session.execute(
select(StoredOfflineToken).where(
StoredOfflineToken.user_id == self.user_id
)
with self.session_maker() as session:
token_record = (
session.query(StoredOfflineToken)
.filter(StoredOfflineToken.user_id == self.user_id)
.first()
)
token_record = result.scalar_one_or_none()
if not token_record:
return None
@@ -57,4 +56,4 @@ class OfflineTokenStore:
logger.debug(f'offline_token_store.get_instance::{user_id}')
if user_id:
user_id = str(user_id)
return OfflineTokenStore(user_id, config)
return OfflineTokenStore(user_id, session_maker, config)

View File

@@ -3,7 +3,6 @@ Service class for managing organization operations.
Separates business logic from route handlers.
"""
from typing import NoReturn
from uuid import UUID, uuid4
from uuid import UUID as parse_uuid
@@ -326,7 +325,7 @@ class OrgService:
user_id: str,
original_error: Exception,
error_message: str,
) -> NoReturn:
) -> None:
"""
Handle failure by cleaning up LiteLLM resources and raising appropriate error.

View File

@@ -10,6 +10,7 @@ from integrations.github.github_types import (
WorkflowRunStatus,
)
from sqlalchemy import and_, delete, select, update
from sqlalchemy.orm import sessionmaker
from storage.database import a_session_maker
from storage.proactive_convos import ProactiveConversation
@@ -19,6 +20,8 @@ from openhands.integrations.service_types import ProviderType
@dataclass
class ProactiveConversationStore:
a_session_maker: sessionmaker = a_session_maker
def get_repo_id(self, provider: ProviderType, repo_id):
return f'{provider.value}##{repo_id}'
@@ -48,7 +51,7 @@ class ProactiveConversationStore:
final_workflow_group = None
async with a_session_maker() as session:
async with self.a_session_maker() as session:
# Start an explicit transaction with row-level locking
async with session.begin():
# Get the existing proactive conversation entry with FOR UPDATE lock
@@ -139,7 +142,7 @@ class ProactiveConversationStore:
# Calculate the cutoff time (current time - older_than_minutes)
cutoff_time = datetime.now(UTC) - timedelta(minutes=older_than_minutes)
async with a_session_maker() as session:
async with self.a_session_maker() as session:
async with session.begin():
# Delete records older than the cutoff time
delete_stmt = delete(ProactiveConversation).where(
@@ -155,9 +158,9 @@ class ProactiveConversationStore:
@classmethod
async def get_instance(cls) -> ProactiveConversationStore:
"""Get an instance of the ProactiveConversationStore.
"""Get an instance of the GitlabWebhookStore.
Returns:
An instance of ProactiveConversationStore
An instance of GitlabWebhookStore
"""
return ProactiveConversationStore()
return ProactiveConversationStore(a_session_maker)

View File

@@ -2,8 +2,8 @@ from __future__ import annotations
from dataclasses import dataclass
from sqlalchemy import select
from storage.database import a_session_maker
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from storage.stored_repository import StoredRepository
from openhands.core.config.openhands_config import OpenHandsConfig
@@ -11,11 +11,12 @@ from openhands.core.config.openhands_config import OpenHandsConfig
@dataclass
class RepositoryStore:
session_maker: sessionmaker
config: OpenHandsConfig
async def store_projects(self, repositories: list[StoredRepository]) -> None:
def store_projects(self, repositories: list[StoredRepository]) -> None:
"""
Store repositories in database (async version)
Store repositories in database
1. Make sure to store repositories if its ID doesn't exist
2. If repository ID already exists, make sure to only update the repo is_public and repo_name fields
@@ -25,15 +26,17 @@ class RepositoryStore:
if not repositories:
return
async with a_session_maker() as session:
with self.session_maker() as session:
# Extract all repo_ids to check
repo_ids = [r.repo_id for r in repositories]
# Get all existing repositories in a single query
result = await session.execute(
select(StoredRepository).filter(StoredRepository.repo_id.in_(repo_ids))
)
existing_repos = {r.repo_id: r for r in result.scalars().all()}
existing_repos = {
r.repo_id: r
for r in session.query(StoredRepository).filter(
StoredRepository.repo_id.in_(repo_ids)
)
}
# Process all repositories
for repo in repositories:
@@ -47,9 +50,9 @@ class RepositoryStore:
session.add(repo)
# Commit all changes
await session.commit()
session.commit()
@classmethod
def get_instance(cls, config: OpenHandsConfig) -> RepositoryStore:
"""Get an instance of the UserRepositoryStore."""
return RepositoryStore(config)
return RepositoryStore(session_maker, config)

View File

@@ -234,8 +234,6 @@ class SaasConversationStore(ConversationStore):
cls, config: OpenHandsConfig, user_id: str | None
) -> ConversationStore:
# user_id should not be None in SaaS, should we raise?
# Use async version since callers now use asyncio.run_coroutine_threadsafe()
# to dispatch to the main event loop where asyncpg connections work properly.
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id if user else None
return SaasConversationStore(str(user_id), org_id, session_maker)

View File

@@ -28,7 +28,7 @@ class SaasConversationValidator(ConversationValidator):
# Validate the API key and get the user_id
api_key_store = ApiKeyStore.get_instance()
user_id = await api_key_store.validate_api_key(api_key)
user_id = api_key_store.validate_api_key(api_key)
if not user_id:
logger.warning('Invalid API key')

View File

@@ -5,8 +5,8 @@ from base64 import b64decode, b64encode
from dataclasses import dataclass
from cryptography.fernet import Fernet
from sqlalchemy import delete, select
from storage.database import a_session_maker
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from storage.stored_custom_secrets import StoredCustomSecrets
from storage.user_store import UserStore
@@ -19,6 +19,7 @@ from openhands.storage.secrets.secrets_store import SecretsStore
@dataclass
class SaasSecretsStore(SecretsStore):
user_id: str
session_maker: sessionmaker
config: OpenHandsConfig
async def load(self) -> Secrets | None:
@@ -27,15 +28,14 @@ class SaasSecretsStore(SecretsStore):
user = await UserStore.get_user_by_id_async(self.user_id)
org_id = user.current_org_id if user else None
async with a_session_maker() as session:
with self.session_maker() as session:
# Fetch all secrets for the given user ID
query = select(StoredCustomSecrets).filter(
query = session.query(StoredCustomSecrets).filter(
StoredCustomSecrets.keycloak_user_id == self.user_id
)
if org_id is not None:
query = query.filter(StoredCustomSecrets.org_id == org_id)
result = await session.execute(query)
settings = result.scalars().all()
settings = query.all()
if not settings:
return Secrets()
@@ -54,15 +54,12 @@ class SaasSecretsStore(SecretsStore):
async def store(self, item: Secrets):
user = await UserStore.get_user_by_id_async(self.user_id)
org_id = user.current_org_id
async with a_session_maker() as session:
with self.session_maker() as session:
# Incoming secrets are always the most updated ones
# Delete all existing records and override with incoming ones
await session.execute(
delete(StoredCustomSecrets).filter(
StoredCustomSecrets.keycloak_user_id == self.user_id
)
)
session.query(StoredCustomSecrets).filter(
StoredCustomSecrets.keycloak_user_id == self.user_id
).delete()
# Prepare the new secrets data
kwargs = item.model_dump(context={'expose_secrets': True})
@@ -92,7 +89,7 @@ class SaasSecretsStore(SecretsStore):
)
session.add(new_secret)
await session.commit()
session.commit()
def _decrypt_kwargs(self, kwargs: dict):
fernet = self._fernet()
@@ -136,4 +133,4 @@ class SaasSecretsStore(SecretsStore):
if not user_id:
raise Exception('SaasSecretsStore cannot be constructed with no user_id')
logger.debug(f'saas_secrets_store.get_instance::{user_id}')
return SaasSecretsStore(user_id, config)
return SaasSecretsStore(user_id, session_maker, config)

View File

@@ -10,9 +10,8 @@ from cryptography.fernet import Fernet
from pydantic import SecretStr
from server.constants import LITE_LLM_API_URL
from server.logger import logger
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker
from sqlalchemy.orm import joinedload, sessionmaker
from storage.database import session_maker
from storage.lite_llm_manager import LiteLlmManager, get_openhands_cloud_key_alias
from storage.org import Org
from storage.org_member import OrgMember
@@ -24,24 +23,26 @@ from storage.user_store import UserStore
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.server.settings import Settings
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.llm import is_openhands_model
@dataclass
class SaasSettingsStore(SettingsStore):
user_id: str
session_maker: sessionmaker
config: OpenHandsConfig
ENCRYPT_VALUES = ['llm_api_key', 'llm_api_key_for_byor', 'search_api_key']
async def _get_user_settings_by_keycloak_id_async(
def _get_user_settings_by_keycloak_id(
self, keycloak_user_id: str, session=None
) -> UserSettings | None:
"""
Get UserSettings by keycloak_user_id (async version).
Get UserSettings by keycloak_user_id.
Args:
keycloak_user_id: The keycloak user ID to search for
session: Optional existing async database session. If not provided, creates a new one.
session: Optional existing database session. If not provided, creates a new one.
Returns:
UserSettings object if found, None otherwise
@@ -49,26 +50,27 @@ class SaasSettingsStore(SettingsStore):
if not keycloak_user_id:
return None
if session:
# Use provided session
result = await session.execute(
select(UserSettings).filter(
UserSettings.keycloak_user_id == keycloak_user_id
def _get_settings():
if session:
# Use provided session
return (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == keycloak_user_id)
.first()
)
)
return result.scalars().first()
else:
# Create new session
async with a_session_maker() as new_session:
result = await new_session.execute(
select(UserSettings).filter(
UserSettings.keycloak_user_id == keycloak_user_id
else:
# Create new session
with self.session_maker() as new_session:
return (
new_session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == keycloak_user_id)
.first()
)
)
return result.scalars().first()
return _get_settings()
async def load(self) -> Settings | None:
user = await UserStore.get_user_by_id_async(self.user_id)
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
if not user:
logger.error(f'User not found for ID {self.user_id}')
return None
@@ -81,7 +83,7 @@ class SaasSettingsStore(SettingsStore):
break
if not org_member or not org_member.llm_api_key:
return None
org = await OrgStore.get_org_by_id_async(org_id)
org = OrgStore.get_org_by_id(org_id)
if not org:
logger.error(
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
@@ -120,22 +122,21 @@ class SaasSettingsStore(SettingsStore):
return settings
async def store(self, item: Settings):
async with a_session_maker() as session:
with self.session_maker() as session:
if not item:
return None
result = await session.execute(
select(User)
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(self.user_id))
)
user = result.scalars().first()
).first()
if not user:
# Check if we need to migrate from user_settings
user_settings = None
async with a_session_maker() as new_session:
user_settings = await self._get_user_settings_by_keycloak_id_async(
self.user_id, new_session
with session_maker() as session:
user_settings = self._get_user_settings_by_keycloak_id(
self.user_id, session
)
if user_settings:
user = await UserStore.migrate_user(self.user_id, user_settings)
@@ -153,8 +154,7 @@ class SaasSettingsStore(SettingsStore):
if not org_member or not org_member.llm_api_key:
return None
result = await session.execute(select(Org).filter(Org.id == org_id))
org = result.scalars().first()
org: Org = session.query(Org).filter(Org.id == org_id).first()
if not org:
logger.error(
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
@@ -173,7 +173,7 @@ class SaasSettingsStore(SettingsStore):
if hasattr(model, key):
setattr(model, key, value)
await session.commit()
session.commit()
@classmethod
async def get_instance(
@@ -182,7 +182,7 @@ class SaasSettingsStore(SettingsStore):
user_id: str, # type: ignore[override]
) -> SaasSettingsStore:
logger.debug(f'saas_settings_store.get_instance::{user_id}')
return SaasSettingsStore(user_id, config)
return SaasSettingsStore(user_id, session_maker, config)
def _should_encrypt(self, key):
return key in self.ENCRYPT_VALUES

View File

@@ -2,35 +2,38 @@ from __future__ import annotations
from dataclasses import dataclass
from sqlalchemy import select
from storage.database import a_session_maker
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from storage.slack_conversation import SlackConversation
@dataclass
class SlackConversationStore:
session_maker: sessionmaker
async def get_slack_conversation(
self, channel_id: str, parent_id: str
) -> SlackConversation | None:
"""Get a slack conversation by channel_id and message_ts.
Both parameters are required to match for a conversation to be returned.
"""
async with a_session_maker() as session:
result = await session.execute(
select(SlackConversation).where(
SlackConversation.channel_id == channel_id,
SlackConversation.parent_id == parent_id,
)
with session_maker() as session:
conversation = (
session.query(SlackConversation)
.filter(SlackConversation.channel_id == channel_id)
.filter(SlackConversation.parent_id == parent_id)
.first()
)
return result.scalar_one_or_none()
return conversation
async def create_slack_conversation(
self, slack_converstion: SlackConversation
) -> None:
async with a_session_maker() as session:
with self.session_maker() as session:
session.merge(slack_converstion)
await session.commit()
session.commit()
@classmethod
def get_instance(cls) -> SlackConversationStore:
return SlackConversationStore()
return SlackConversationStore(session_maker)

View File

@@ -32,7 +32,6 @@ class SlackTeamStore:
# Store the token
session.add(slack_team)
session.commit()
return slack_team
@classmethod
def get_instance(cls):

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
from dataclasses import dataclass
import sqlalchemy
from sqlalchemy import select
from storage.database import a_session_maker
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from storage.user_repo_map import UserRepositoryMap
from openhands.core.config.openhands_config import OpenHandsConfig
@@ -12,11 +12,12 @@ from openhands.core.config.openhands_config import OpenHandsConfig
@dataclass
class UserRepositoryMapStore:
session_maker: sessionmaker
config: OpenHandsConfig
async def store_user_repo_mappings(self, mappings: list[UserRepositoryMap]) -> None:
def store_user_repo_mappings(self, mappings: list[UserRepositoryMap]) -> None:
"""
Store user-repository mappings in database (async version)
Store user-repository mappings in database
1. Make sure to store mappings if they don't exist
2. If a mapping already exists (same user_id and repo_id), update the admin field
@@ -29,20 +30,18 @@ class UserRepositoryMapStore:
if not mappings:
return
async with a_session_maker() as session:
with self.session_maker() as session:
# Extract all user_id/repo_id pairs to check
mapping_keys = [(m.user_id, m.repo_id) for m in mappings]
# Get all existing mappings in a single query
result = await session.execute(
select(UserRepositoryMap).filter(
existing_mappings = {
(m.user_id, m.repo_id): m
for m in session.query(UserRepositoryMap).filter(
sqlalchemy.tuple_(
UserRepositoryMap.user_id, UserRepositoryMap.repo_id
).in_(mapping_keys)
)
)
existing_mappings = {
(m.user_id, m.repo_id): m for m in result.scalars().all()
}
# Process all mappings
@@ -57,9 +56,9 @@ class UserRepositoryMapStore:
session.add(mapping)
# Commit all changes
await session.commit()
session.commit()
@classmethod
def get_instance(cls, config: OpenHandsConfig) -> UserRepositoryMapStore:
"""Get an instance of the UserRepositoryMapStore."""
return UserRepositoryMapStore(config)
return UserRepositoryMapStore(session_maker, config)

View File

@@ -227,7 +227,7 @@ class UserStore:
'user_store:migrate_user:calling_stripe_migrate_customer',
extra={'user_id': user_id},
)
await migrate_customer(user_id, org)
await migrate_customer(session, user_id, org)
logger.debug(
'user_store:migrate_user:done_stripe_migrate_customer',
extra={'user_id': user_id},

View File

@@ -8,16 +8,10 @@ from server.verified_models.verified_model_service import (
StoredVerifiedModel, # noqa: F401
)
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import sessionmaker
from storage.base import Base
# Anything not loaded here may not have a table created for it.
from storage.api_key import ApiKey # noqa: F401
from storage.base import Base
from storage.billing_session import BillingSession
from storage.conversation_work import ConversationWork
from storage.device_code import DeviceCode # noqa: F401
@@ -36,18 +30,9 @@ from storage.stripe_customer import StripeCustomer
from storage.user import User
@pytest.fixture(scope='function')
def db_path(tmp_path):
"""Create a unique temp file path for each test."""
return str(tmp_path / 'test.db')
@pytest.fixture
def engine(db_path):
"""Create a sync engine with tables using file-based DB."""
engine = create_engine(
f'sqlite:///{db_path}', connect_args={'check_same_thread': False}
)
def engine():
engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)
return engine
@@ -57,36 +42,6 @@ def session_maker(engine):
return sessionmaker(bind=engine)
@pytest.fixture
def async_engine(db_path):
"""Create an async engine using the SAME file-based database."""
async_engine = create_async_engine(
f'sqlite+aiosqlite:///{db_path}',
connect_args={'check_same_thread': False},
)
async def create_tables():
async with async_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Run the async function synchronously
import asyncio
asyncio.run(create_tables())
return async_engine
@pytest.fixture
async def async_session_maker(async_engine):
"""Create an async session maker bound to the async engine."""
async_session_maker = async_sessionmaker(
bind=async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
return async_session_maker
def add_minimal_fixtures(session_maker):
with session_maker() as session:
session.add(

View File

@@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from integrations.jira.jira_manager import JiraManager
from integrations.jira.jira_payload import JiraEventType, JiraWebhookPayload
from integrations.models import Message, SourceType
from openhands.server.types import (
LLMAuthenticationError,
@@ -273,8 +274,9 @@ class TestSendMessage:
return_value=mock_response
)
message = Message(source=SourceType.JIRA, message='Test message')
result = await jira_manager.send_message(
'Test message',
message,
'PROJ-123',
'cloud-123',
'service@test.com',

View File

@@ -1,268 +0,0 @@
"""
Tests for JiraPayloadParser.
These tests verify the parsing behavior of Jira webhook payloads,
including the handling of optional fields like user_email which
may not be present in webhook payloads from Jira.
"""
import pytest
from integrations.jira.jira_payload import (
JiraEventType,
JiraPayloadError,
JiraPayloadParser,
JiraPayloadSkipped,
JiraPayloadSuccess,
)
@pytest.fixture
def parser():
"""Create a JiraPayloadParser with standard OpenHands labels."""
return JiraPayloadParser(oh_label='openhands', inline_oh_label='@openhands')
@pytest.fixture
def valid_label_payload():
"""Create a valid jira:issue_updated payload with OpenHands label."""
return {
'webhookEvent': 'jira:issue_updated',
'issue': {
'id': '12345',
'key': 'TEST-123',
'self': 'https://test.atlassian.net/rest/api/2/issue/12345',
},
'user': {
'displayName': 'Test User',
'accountId': 'account-123',
'emailAddress': 'test@example.com',
},
'changelog': {
'items': [
{
'field': 'labels',
'toString': 'openhands',
}
]
},
}
@pytest.fixture
def valid_comment_payload():
"""Create a valid comment_created payload with OpenHands mention."""
return {
'webhookEvent': 'comment_created',
'issue': {
'id': '12345',
'key': 'TEST-123',
'self': 'https://test.atlassian.net/rest/api/2/issue/12345',
},
'comment': {
'body': '@openhands please fix this bug',
'author': {
'displayName': 'Test User',
'accountId': 'account-123',
'emailAddress': 'test@example.com',
},
},
}
class TestUserEmailOptional:
"""Tests verifying user_email is optional in webhook payloads.
Jira webhooks may not include emailAddress in the user data.
The parser should accept payloads without this field.
"""
def test_label_event_succeeds_without_email_address(
self, parser, valid_label_payload
):
"""Verify label event parsing succeeds when emailAddress is missing."""
# Arrange - remove emailAddress from user data
del valid_label_payload['user']['emailAddress']
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.user_email == ''
assert result.payload.display_name == 'Test User'
assert result.payload.account_id == 'account-123'
def test_comment_event_succeeds_without_email_address(
self, parser, valid_comment_payload
):
"""Verify comment event parsing succeeds when emailAddress is missing."""
# Arrange - remove emailAddress from author data
del valid_comment_payload['comment']['author']['emailAddress']
# Act
result = parser.parse(valid_comment_payload)
# Assert
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.user_email == ''
assert result.payload.display_name == 'Test User'
assert result.payload.account_id == 'account-123'
def test_user_email_preserved_when_present(self, parser, valid_label_payload):
"""Verify user_email is captured when emailAddress is present."""
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.user_email == 'test@example.com'
class TestRequiredFieldValidation:
"""Tests verifying required fields are still validated."""
def test_missing_issue_id_returns_error(self, parser, valid_label_payload):
"""Verify parsing fails when issue.id is missing."""
# Arrange
del valid_label_payload['issue']['id']
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadError)
assert 'issue.id' in result.error
def test_missing_issue_key_returns_error(self, parser, valid_label_payload):
"""Verify parsing fails when issue.key is missing."""
# Arrange
del valid_label_payload['issue']['key']
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadError)
assert 'issue.key' in result.error
def test_missing_display_name_returns_error(self, parser, valid_label_payload):
"""Verify parsing fails when user.displayName is missing."""
# Arrange
del valid_label_payload['user']['displayName']
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadError)
assert 'displayName' in result.error
def test_missing_account_id_returns_error(self, parser, valid_label_payload):
"""Verify parsing fails when user.accountId is missing."""
# Arrange
del valid_label_payload['user']['accountId']
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadError)
assert 'accountId' in result.error
def test_missing_issue_self_url_returns_error(self, parser, valid_label_payload):
"""Verify parsing fails when issue.self URL is missing."""
# Arrange
del valid_label_payload['issue']['self']
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadError)
assert 'workspace_name' in result.error or 'base_api_url' in result.error
class TestEventTypeDetection:
"""Tests for webhook event type detection."""
def test_issue_updated_with_label_returns_labeled_ticket(
self, parser, valid_label_payload
):
"""Verify jira:issue_updated with label is detected as LABELED_TICKET."""
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.event_type == JiraEventType.LABELED_TICKET
def test_comment_created_with_mention_returns_comment_mention(
self, parser, valid_comment_payload
):
"""Verify comment_created with mention is detected as COMMENT_MENTION."""
# Act
result = parser.parse(valid_comment_payload)
# Assert
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.event_type == JiraEventType.COMMENT_MENTION
def test_unhandled_event_type_returns_skipped(self, parser):
"""Verify unknown event types are skipped."""
# Arrange
payload = {'webhookEvent': 'jira:issue_deleted'}
# Act
result = parser.parse(payload)
# Assert
assert isinstance(result, JiraPayloadSkipped)
assert 'Unhandled' in result.skip_reason
class TestLabelFiltering:
"""Tests for OpenHands label filtering."""
def test_label_event_without_openhands_label_skipped(
self, parser, valid_label_payload
):
"""Verify label events without OpenHands label are skipped."""
# Arrange - change label to something else
valid_label_payload['changelog']['items'][0]['toString'] = 'other-label'
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadSkipped)
assert 'openhands' in result.skip_reason
class TestCommentFiltering:
"""Tests for OpenHands comment mention filtering."""
def test_comment_without_mention_skipped(self, parser, valid_comment_payload):
"""Verify comments without OpenHands mention are skipped."""
# Arrange - remove mention from comment body
valid_comment_payload['comment']['body'] = 'Please fix this bug'
# Act
result = parser.parse(valid_comment_payload)
# Assert
assert isinstance(result, JiraPayloadSkipped)
assert '@openhands' in result.skip_reason
class TestWorkspaceExtraction:
"""Tests for workspace name extraction from issue URL."""
def test_workspace_name_extracted_from_self_url(self, parser, valid_label_payload):
"""Verify workspace name is extracted from issue self URL."""
# Act
result = parser.parse(valid_label_payload)
# Assert
assert isinstance(result, JiraPayloadSuccess)
assert result.payload.workspace_name == 'test.atlassian.net'
assert result.payload.base_api_url == 'https://test.atlassian.net'

View File

@@ -738,7 +738,7 @@ class TestStartJob:
# Should send error message about re-login
jira_dc_manager.send_message.assert_called_once()
call_args = jira_dc_manager.send_message.call_args[0]
assert 'Please re-login' in call_args[0]
assert 'Please re-login' in call_args[0].message
@pytest.mark.asyncio
async def test_start_job_llm_authentication_error(
@@ -763,7 +763,7 @@ class TestStartJob:
# Should send error message about LLM API key
jira_dc_manager.send_message.assert_called_once()
call_args = jira_dc_manager.send_message.call_args[0]
assert 'valid LLM API key' in call_args[0]
assert 'valid LLM API key' in call_args[0].message
@pytest.mark.asyncio
async def test_start_job_session_expired_error(
@@ -788,8 +788,8 @@ class TestStartJob:
# Should send error message about session expired
jira_dc_manager.send_message.assert_called_once()
call_args = jira_dc_manager.send_message.call_args[0]
assert 'session has expired' in call_args[0]
assert 'login again' in call_args[0]
assert 'session has expired' in call_args[0].message
assert 'login again' in call_args[0].message
@pytest.mark.asyncio
async def test_start_job_unexpected_error(
@@ -814,7 +814,7 @@ class TestStartJob:
# Should send generic error message
jira_dc_manager.send_message.assert_called_once()
call_args = jira_dc_manager.send_message.call_args[0]
assert 'unexpected error' in call_args[0]
assert 'unexpected error' in call_args[0].message
@pytest.mark.asyncio
async def test_start_job_send_message_fails(
@@ -943,8 +943,9 @@ class TestSendMessage:
return_value=mock_response
)
message = Message(source=SourceType.JIRA_DC, message='Test message')
result = await jira_dc_manager.send_message(
'Test message', 'PROJ-123', 'https://jira.company.com', 'bearer_token'
message, 'PROJ-123', 'https://jira.company.com', 'bearer_token'
)
assert result == {'id': 'comment_id'}
@@ -1013,7 +1014,7 @@ class TestSendRepoSelectionComment:
jira_dc_manager.send_message.assert_called_once()
call_args = jira_dc_manager.send_message.call_args[0]
assert 'which repository to work with' in call_args[0]
assert 'which repository to work with' in call_args[0].message
@pytest.mark.asyncio
async def test_send_repo_selection_comment_send_fails(

View File

@@ -18,11 +18,9 @@ from openhands.core.schema.agent import AgentState
class TestJiraDcNewConversationView:
"""Tests for JiraDcNewConversationView"""
async def test_get_instructions(self, new_conversation_view, mock_jinja_env):
def test_get_instructions(self, new_conversation_view, mock_jinja_env):
"""Test _get_instructions method"""
instructions, user_msg = await new_conversation_view._get_instructions(
mock_jinja_env
)
instructions, user_msg = new_conversation_view._get_instructions(mock_jinja_env)
assert instructions == 'Test Jira DC instructions template'
assert 'PROJ-123' in user_msg
@@ -85,9 +83,9 @@ class TestJiraDcNewConversationView:
class TestJiraDcExistingConversationView:
"""Tests for JiraDcExistingConversationView"""
async def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
"""Test _get_instructions method"""
instructions, user_msg = await existing_conversation_view._get_instructions(
instructions, user_msg = existing_conversation_view._get_instructions(
mock_jinja_env
)

View File

@@ -802,7 +802,7 @@ class TestStartJob:
# Should send error message about re-login
linear_manager.send_message.assert_called_once()
call_args = linear_manager.send_message.call_args[0]
assert 'Please re-login' in call_args[0]
assert 'Please re-login' in call_args[0].message
@pytest.mark.asyncio
async def test_start_job_llm_authentication_error(
@@ -828,7 +828,7 @@ class TestStartJob:
# Should send error message about LLM API key
linear_manager.send_message.assert_called_once()
call_args = linear_manager.send_message.call_args[0]
assert 'valid LLM API key' in call_args[0]
assert 'valid LLM API key' in call_args[0].message
@pytest.mark.asyncio
async def test_start_job_session_expired_error(
@@ -854,8 +854,8 @@ class TestStartJob:
# Should send error message about session expired
linear_manager.send_message.assert_called_once()
call_args = linear_manager.send_message.call_args[0]
assert 'session has expired' in call_args[0]
assert 'login again' in call_args[0]
assert 'session has expired' in call_args[0].message
assert 'login again' in call_args[0].message
@pytest.mark.asyncio
async def test_start_job_unexpected_error(
@@ -881,7 +881,7 @@ class TestStartJob:
# Should send generic error message
linear_manager.send_message.assert_called_once()
call_args = linear_manager.send_message.call_args[0]
assert 'unexpected error' in call_args[0]
assert 'unexpected error' in call_args[0].message
@pytest.mark.asyncio
async def test_start_job_send_message_fails(
@@ -1049,9 +1049,8 @@ class TestSendMessage:
linear_manager._query_api = AsyncMock(return_value=mock_response)
result = await linear_manager.send_message(
'Test message', 'issue_id', 'api_key'
)
message = Message(source=SourceType.LINEAR, message='Test message')
result = await linear_manager.send_message(message, 'issue_id', 'api_key')
assert result == mock_response
linear_manager._query_api.assert_called_once()
@@ -1115,7 +1114,7 @@ class TestSendRepoSelectionComment:
linear_manager.send_message.assert_called_once()
call_args = linear_manager.send_message.call_args[0]
assert 'which repository to work with' in call_args[0]
assert 'which repository to work with' in call_args[0].message
@pytest.mark.asyncio
async def test_send_repo_selection_comment_send_fails(

View File

@@ -18,11 +18,9 @@ from openhands.core.schema.agent import AgentState
class TestLinearNewConversationView:
"""Tests for LinearNewConversationView"""
async def test_get_instructions(self, new_conversation_view, mock_jinja_env):
def test_get_instructions(self, new_conversation_view, mock_jinja_env):
"""Test _get_instructions method"""
instructions, user_msg = await new_conversation_view._get_instructions(
mock_jinja_env
)
instructions, user_msg = new_conversation_view._get_instructions(mock_jinja_env)
assert instructions == 'Test instructions template'
assert 'TEST-123' in user_msg
@@ -85,9 +83,9 @@ class TestLinearNewConversationView:
class TestLinearExistingConversationView:
"""Tests for LinearExistingConversationView"""
async def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
"""Test _get_instructions method"""
instructions, user_msg = await existing_conversation_view._get_instructions(
instructions, user_msg = existing_conversation_view._get_instructions(
mock_jinja_env
)

View File

@@ -263,9 +263,7 @@ class TestPausedSandboxResumption:
@patch('openhands.app_server.config.get_httpx_client')
@patch('openhands.app_server.event_callback.util.ensure_running_sandbox')
@patch('openhands.app_server.event_callback.util.get_agent_server_url_from_sandbox')
@patch.object(
SlackUpdateExistingConversationView, '_get_instructions', new_callable=AsyncMock
)
@patch.object(SlackUpdateExistingConversationView, '_get_instructions')
async def test_paused_sandbox_resumption(
self,
mock_get_instructions,

View File

@@ -34,6 +34,7 @@ async def test_send_comment_to_jira_success(mock_jira_manager, processor):
)
mock_jira_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_jira_manager.send_message = AsyncMock()
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
# Action
await processor._send_comment_to_jira('This is a summary.')
@@ -129,6 +130,7 @@ async def test_call_sends_summary_to_jira(
return_value=mock_workspace
)
mock_jira_manager.send_message = AsyncMock()
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
with patch(
'server.conversation_callback_processor.jira_callback_processor.asyncio.create_task'
@@ -198,6 +200,7 @@ async def test_send_comment_to_jira_api_error(mock_jira_manager, processor):
)
mock_jira_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_jira_manager.send_message = AsyncMock(side_effect=Exception('API Error'))
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
# Action - should not raise exception, but handle it gracefully
await processor._send_comment_to_jira('This is a summary.')
@@ -325,15 +328,18 @@ async def test_send_comment_to_jira_message_construction(mock_jira_manager, proc
)
mock_jira_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_jira_manager.send_message = AsyncMock()
mock_outgoing_message = MagicMock()
mock_jira_manager.create_outgoing_message.return_value = mock_outgoing_message
test_message = 'This is a test summary message.'
# Action
await processor._send_comment_to_jira(test_message)
# Assert - send_message now receives the string directly
# Assert
mock_jira_manager.create_outgoing_message.assert_called_once_with(msg=test_message)
mock_jira_manager.send_message.assert_called_once_with(
test_message,
mock_outgoing_message,
issue_key='TEST-123',
jira_cloud_id='cloud123',
svc_acc_email='service@test.com',
@@ -380,6 +386,7 @@ async def test_call_creates_background_task_for_sending(
return_value=mock_workspace
)
mock_jira_manager.send_message = AsyncMock()
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
with patch(
'server.conversation_callback_processor.jira_callback_processor.asyncio.create_task'

View File

@@ -32,6 +32,7 @@ async def test_send_comment_to_jira_dc_success(mock_jira_dc_manager, processor):
)
mock_jira_dc_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_jira_dc_manager.send_message = AsyncMock()
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
# Action
await processor._send_comment_to_jira_dc('This is a summary.')
@@ -124,6 +125,7 @@ async def test_call_sends_summary_to_jira_dc(
return_value=mock_workspace
)
mock_jira_dc_manager.send_message = AsyncMock()
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
with patch(
'server.conversation_callback_processor.jira_dc_callback_processor.asyncio.create_task'
@@ -198,6 +200,7 @@ async def test_send_comment_to_jira_dc_api_error(mock_jira_dc_manager, processor
)
mock_jira_dc_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_jira_dc_manager.send_message = AsyncMock(side_effect=Exception('API Error'))
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
# Action - should not raise exception, but handle it gracefully
await processor._send_comment_to_jira_dc('This is a summary.')
@@ -325,15 +328,20 @@ async def test_send_comment_to_jira_dc_message_construction(
)
mock_jira_dc_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_jira_dc_manager.send_message = AsyncMock()
mock_outgoing_message = MagicMock()
mock_jira_dc_manager.create_outgoing_message.return_value = mock_outgoing_message
test_message = 'This is a test summary message.'
# Action
await processor._send_comment_to_jira_dc(test_message)
# Assert - send_message now receives the string directly
# Assert
mock_jira_dc_manager.create_outgoing_message.assert_called_once_with(
msg=test_message
)
mock_jira_dc_manager.send_message.assert_called_once_with(
test_message,
mock_outgoing_message,
issue_key='TEST-123',
base_api_url='https://test-jira-dc.company.com',
svc_acc_api_key='decrypted_key',
@@ -376,6 +384,7 @@ async def test_call_creates_background_task_for_sending(
return_value=mock_workspace
)
mock_jira_dc_manager.send_message = AsyncMock()
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
with patch(
'server.conversation_callback_processor.jira_dc_callback_processor.asyncio.create_task'

View File

@@ -32,6 +32,7 @@ async def test_send_comment_to_linear_success(mock_linear_manager, processor):
)
mock_linear_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_linear_manager.send_message = AsyncMock()
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
# Action
await processor._send_comment_to_linear('This is a summary.')
@@ -124,6 +125,7 @@ async def test_call_sends_summary_to_linear(
return_value=mock_workspace
)
mock_linear_manager.send_message = AsyncMock()
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
with patch(
'server.conversation_callback_processor.linear_callback_processor.asyncio.create_task'
@@ -198,6 +200,7 @@ async def test_send_comment_to_linear_api_error(mock_linear_manager, processor):
)
mock_linear_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_linear_manager.send_message = AsyncMock(side_effect=Exception('API Error'))
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
# Action - should not raise exception, but handle it gracefully
await processor._send_comment_to_linear('This is a summary.')
@@ -325,15 +328,20 @@ async def test_send_comment_to_linear_message_construction(
)
mock_linear_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
mock_linear_manager.send_message = AsyncMock()
mock_outgoing_message = MagicMock()
mock_linear_manager.create_outgoing_message.return_value = mock_outgoing_message
test_message = 'This is a test summary message.'
# Action
await processor._send_comment_to_linear(test_message)
# Assert - send_message now receives the string directly
# Assert
mock_linear_manager.create_outgoing_message.assert_called_once_with(
msg=test_message
)
mock_linear_manager.send_message.assert_called_once_with(
test_message,
mock_outgoing_message,
'TEST-123', # issue_id
'decrypted_key', # api_key
)
@@ -375,6 +383,7 @@ async def test_call_creates_background_task_for_sending(
return_value=mock_workspace
)
mock_linear_manager.send_message = AsyncMock()
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
with patch(
'server.conversation_callback_processor.linear_callback_processor.asyncio.create_task'

View File

@@ -16,15 +16,8 @@ from storage.device_code import DeviceCode
@pytest.fixture
def mock_device_code_store():
"""Mock device code store with async methods."""
mock = MagicMock()
mock.create_device_code = AsyncMock()
mock.get_by_device_code = AsyncMock()
mock.get_by_user_code = AsyncMock()
mock.authorize_device_code = AsyncMock()
mock.deny_device_code = AsyncMock()
mock.update_poll_time = AsyncMock()
return mock
"""Mock device code store."""
return MagicMock()
@pytest.fixture
@@ -61,7 +54,7 @@ class TestDeviceAuthorization:
expires_at=datetime.now(UTC) + timedelta(minutes=10),
current_interval=5, # Default interval
)
mock_store.create_device_code = AsyncMock(return_value=mock_device)
mock_store.create_device_code.return_value = mock_device
result = await device_authorization(mock_request)
@@ -83,7 +76,7 @@ class TestDeviceAuthorization:
expires_at=datetime.now(UTC) + timedelta(minutes=10),
current_interval=15, # Increased interval from previous rate limiting
)
mock_store.create_device_code = AsyncMock(return_value=mock_device)
mock_store.create_device_code.return_value = mock_device
result = await device_authorization(mock_request)
@@ -120,10 +113,10 @@ class TestDeviceToken:
mock_device.status = status
# Mock rate limiting - return False (not too fast) and default interval
mock_device.check_rate_limit.return_value = (False, 5)
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
else:
mock_store.get_by_device_code = AsyncMock(return_value=None)
mock_store.get_by_device_code.return_value = None
result = await device_token(device_code=device_code)
@@ -149,14 +142,12 @@ class TestDeviceToken:
)
# Mock rate limiting - return False (not too fast) and default interval
mock_device.check_rate_limit.return_value = (False, 5)
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
# Mock API key retrieval - use AsyncMock for async method
# Mock API key retrieval
mock_api_key_store = MagicMock()
mock_api_key_store.retrieve_api_key_by_name = AsyncMock(
return_value='test-api-key'
)
mock_api_key_store.retrieve_api_key_by_name.return_value = 'test-api-key'
mock_api_key_class.get_instance.return_value = mock_api_key_store
result = await device_token(device_code=device_code)
@@ -185,7 +176,7 @@ class TestDeviceVerificationAuthenticated:
self, mock_store, mock_api_key_class
):
"""Test verification with invalid device code."""
mock_store.get_by_user_code = AsyncMock(return_value=None)
mock_store.get_by_user_code.return_value = None
with pytest.raises(HTTPException):
await device_verification_authenticated(
@@ -198,7 +189,7 @@ class TestDeviceVerificationAuthenticated:
"""Test verification with already processed device code."""
mock_device = MagicMock()
mock_device.is_pending.return_value = False
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
mock_store.get_by_user_code.return_value = mock_device
with pytest.raises(HTTPException):
await device_verification_authenticated(
@@ -212,8 +203,8 @@ class TestDeviceVerificationAuthenticated:
# Mock device code
mock_device = MagicMock()
mock_device.is_pending.return_value = True
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
mock_store.authorize_device_code = AsyncMock(return_value=True)
mock_store.get_by_user_code.return_value = mock_device
mock_store.authorize_device_code.return_value = True
# Mock API key store with async create_api_key
mock_api_key_store = MagicMock()
@@ -257,17 +248,15 @@ class TestDeviceVerificationAuthenticated:
mock_device2.is_pending.return_value = True
# Configure mock store to return appropriate device for each user_code
async def get_by_user_code_side_effect(user_code):
def get_by_user_code_side_effect(user_code):
if user_code == device1_code:
return mock_device1
elif user_code == device2_code:
return mock_device2
return None
mock_store.get_by_user_code = AsyncMock(
side_effect=get_by_user_code_side_effect
)
mock_store.authorize_device_code = AsyncMock(return_value=True)
mock_store.get_by_user_code.side_effect = get_by_user_code_side_effect
mock_store.authorize_device_code.return_value = True
# Authenticate first device
result1 = await device_verification_authenticated(
@@ -316,8 +305,8 @@ class TestDeviceTokenRateLimiting:
last_poll_time=None, # First poll
current_interval=5,
)
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
device_code = 'test_device_code'
result = await device_token(device_code=device_code)
@@ -347,8 +336,8 @@ class TestDeviceTokenRateLimiting:
last_poll_time=last_poll,
current_interval=5,
)
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
device_code = 'test_device_code'
result = await device_token(device_code=device_code)
@@ -378,8 +367,8 @@ class TestDeviceTokenRateLimiting:
last_poll_time=last_poll,
current_interval=5,
)
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
device_code = 'test_device_code'
result = await device_token(device_code=device_code)
@@ -410,8 +399,8 @@ class TestDeviceTokenRateLimiting:
last_poll_time=last_poll,
current_interval=15, # Already increased from previous slow_down
)
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
device_code = 'test_device_code'
result = await device_token(device_code=device_code)
@@ -441,8 +430,8 @@ class TestDeviceTokenRateLimiting:
last_poll_time=last_poll,
current_interval=58, # Near maximum of 60
)
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
device_code = 'test_device_code'
result = await device_token(device_code=device_code)
@@ -468,8 +457,8 @@ class TestDeviceTokenRateLimiting:
last_poll_time=last_poll,
current_interval=5,
)
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
mock_store.update_poll_time = AsyncMock(return_value=True)
mock_store.get_by_device_code.return_value = mock_device
mock_store.update_poll_time.return_value = True
device_code = 'test_device_code'
result = await device_token(device_code=device_code)
@@ -498,10 +487,8 @@ class TestDeviceVerificationTransactionIntegrity:
# Mock device code
mock_device = MagicMock()
mock_device.is_pending.return_value = True
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
mock_store.authorize_device_code = AsyncMock(
return_value=False
) # Authorization fails
mock_store.get_by_user_code.return_value = mock_device
mock_store.authorize_device_code.return_value = False # Authorization fails
# Mock API key store with async create_api_key
mock_api_key_store = MagicMock()
@@ -532,11 +519,9 @@ class TestDeviceVerificationTransactionIntegrity:
# Mock device code
mock_device = MagicMock()
mock_device.is_pending.return_value = True
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
mock_store.authorize_device_code = AsyncMock(
return_value=True
) # Authorization succeeds
mock_store.deny_device_code = AsyncMock(return_value=True) # Cleanup succeeds
mock_store.get_by_user_code.return_value = mock_device
mock_store.authorize_device_code.return_value = True # Authorization succeeds
mock_store.deny_device_code.return_value = True # Cleanup succeeds
# Mock API key store to fail on creation (async)
mock_api_key_store = MagicMock()
@@ -574,12 +559,10 @@ class TestDeviceVerificationTransactionIntegrity:
# Mock device code
mock_device = MagicMock()
mock_device.is_pending.return_value = True
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
mock_store.authorize_device_code = AsyncMock(
return_value=True
) # Authorization succeeds
mock_store.deny_device_code = AsyncMock(
side_effect=Exception('Cleanup failed')
mock_store.get_by_user_code.return_value = mock_device
mock_store.authorize_device_code.return_value = True # Authorization succeeds
mock_store.deny_device_code.side_effect = Exception(
'Cleanup failed'
) # Cleanup fails
# Mock API key store to fail on creation (async)
@@ -612,11 +595,8 @@ class TestDeviceVerificationTransactionIntegrity:
# Mock device code
mock_device = MagicMock()
mock_device.is_pending.return_value = True
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
mock_store.authorize_device_code = AsyncMock(
return_value=True
) # Authorization succeeds
mock_store.deny_device_code = AsyncMock()
mock_store.get_by_user_code.return_value = mock_device
mock_store.authorize_device_code.return_value = True # Authorization succeeds
# Mock API key store with async create_api_key
mock_api_key_store = MagicMock()

View File

@@ -11,37 +11,43 @@ import httpx
import pytest
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.testclient import TestClient
from server.email_validation import get_admin_user_id
from server.routes.org_models import (
CannotModifySelfError,
InsufficientPermissionError,
InvalidRoleError,
LastOwnerError,
LiteLLMIntegrationError,
MeResponse,
OrgAppSettingsResponse,
OrgAppSettingsUpdate,
OrgAuthorizationError,
OrgDatabaseError,
OrgMemberNotFoundError,
OrgMemberPage,
OrgMemberResponse,
OrgMemberUpdate,
OrgNameExistsError,
OrgNotFoundError,
OrphanedUserError,
RoleNotFoundError,
)
from server.routes.orgs import (
get_me,
get_org_members,
org_router,
remove_org_member,
update_org_member,
)
from storage.org import Org
from openhands.server.user_auth import get_user_id
# Mock database before imports
with patch('storage.database.engine', create=True), patch(
'storage.database.a_engine', create=True
):
from server.email_validation import get_admin_user_id
from server.routes.org_models import (
CannotModifySelfError,
InsufficientPermissionError,
InvalidRoleError,
LastOwnerError,
LiteLLMIntegrationError,
MeResponse,
OrgAppSettingsResponse,
OrgAppSettingsUpdate,
OrgAuthorizationError,
OrgDatabaseError,
OrgMemberNotFoundError,
OrgMemberPage,
OrgMemberResponse,
OrgMemberUpdate,
OrgNameExistsError,
OrgNotFoundError,
OrphanedUserError,
RoleNotFoundError,
)
from server.routes.orgs import (
get_me,
get_org_members,
org_router,
remove_org_member,
update_org_member,
)
from storage.org import Org
from openhands.server.user_auth import get_user_id
# Test user ID constant (must be a valid UUID string)
TEST_USER_ID = str(uuid.uuid4())

View File

@@ -399,135 +399,3 @@ class TestUpdateActiveWorkingSeconds:
assert conversation_work.seconds == 23.0
assert conversation_work.conversation_id == conversation_id
assert conversation_work.user_id == user_id
class TestInvokeConversationCallbacks:
"""Tests for invoke_conversation_callbacks function.
This function uses async database sessions (a_session_maker) to query
and invoke callbacks for a conversation.
"""
@pytest.fixture
def mock_observation(self):
"""Create a mock AgentStateChangedObservation."""
observation = Mock(spec=AgentStateChangedObservation)
observation.agent_state = AgentState.FINISHED
return observation
@pytest.fixture
def create_mock_async_session(self):
"""Factory to create properly mocked async session context manager."""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock
def _create(callbacks_list):
mock_session = Mock()
mock_result = Mock()
mock_result.scalars.return_value.all.return_value = callbacks_list
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock(return_value=None)
@asynccontextmanager
async def mock_context_manager():
yield mock_session
return mock_context_manager, mock_session
return _create
@pytest.mark.asyncio
async def test_invoke_callbacks_with_active_callbacks(
self, mock_observation, create_mock_async_session
):
"""Test that active callbacks are invoked successfully."""
from unittest.mock import AsyncMock
# Arrange
conversation_id = 'test_conversation_callbacks'
mock_processor = AsyncMock(return_value=None)
# Create a mock callback
mock_callback = Mock()
mock_callback.id = 1
mock_callback.processor_type = 'test_processor'
mock_callback.get_processor.return_value = mock_processor
mock_context_manager, mock_session = create_mock_async_session([mock_callback])
# Act
with patch(
'server.utils.conversation_callback_utils.a_session_maker',
mock_context_manager,
):
from server.utils.conversation_callback_utils import (
invoke_conversation_callbacks,
)
await invoke_conversation_callbacks(conversation_id, mock_observation)
# Assert
mock_callback.get_processor.assert_called_once()
mock_processor.assert_called_once_with(mock_callback, mock_observation)
@pytest.mark.asyncio
async def test_invoke_callbacks_with_no_active_callbacks(
self, mock_observation, create_mock_async_session
):
"""Test behavior when no active callbacks exist."""
# Arrange
conversation_id = 'test_no_callbacks'
mock_context_manager, mock_session = create_mock_async_session([])
# Act
with patch(
'server.utils.conversation_callback_utils.a_session_maker',
mock_context_manager,
):
from server.utils.conversation_callback_utils import (
invoke_conversation_callbacks,
)
await invoke_conversation_callbacks(conversation_id, mock_observation)
# Assert - should complete without errors
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_invoke_callbacks_handles_processor_exception(
self, mock_observation, create_mock_async_session
):
"""Test that processor exceptions are caught and callback status is updated."""
from unittest.mock import AsyncMock
# Arrange
conversation_id = 'test_callback_error'
mock_processor = AsyncMock(side_effect=Exception('Processor error'))
mock_callback = Mock()
mock_callback.id = 1
mock_callback.processor_type = 'failing_processor'
mock_callback.get_processor.return_value = mock_processor
mock_callback.status = 'active'
mock_context_manager, mock_session = create_mock_async_session([mock_callback])
# Act
with patch(
'server.utils.conversation_callback_utils.a_session_maker',
mock_context_manager,
), patch('server.utils.conversation_callback_utils.logger') as mock_logger:
from server.utils.conversation_callback_utils import (
invoke_conversation_callbacks,
)
from storage.conversation_callback import CallbackStatus
await invoke_conversation_callbacks(conversation_id, mock_observation)
# Assert - callback status should be set to ERROR
assert mock_callback.status == CallbackStatus.ERROR
mock_logger.error.assert_called_once()
error_call = mock_logger.error.call_args
assert error_call[0][0] == 'callback_invocation_failed'

View File

@@ -1,127 +1,127 @@
"""Unit tests for AuthTokenStore using SQLite in-memory database."""
"""Unit tests for AuthTokenStore."""
import time
from unittest.mock import patch
from contextlib import asynccontextmanager
from typing import Dict
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from server.auth.auth_error import TokenRefreshError
from sqlalchemy.exc import OperationalError
from storage.auth_token_store import (
ACCESS_TOKEN_EXPIRY_BUFFER,
LOCK_TIMEOUT_SECONDS,
AuthTokenStore,
)
from storage.auth_tokens import AuthTokens
from storage.base import Base
from openhands.integrations.service_types import ProviderType
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
)
return engine
def create_mock_session():
"""Create a mock async session with properly configured context managers."""
session = AsyncMock()
# Create async context manager for begin()
@asynccontextmanager
async def begin_context():
yield
session.begin = begin_context
return session
def create_mock_session_maker(mock_session):
"""Create a mock async session maker."""
@asynccontextmanager
async def session_context():
yield mock_session
# Return a callable that returns the context manager
return lambda: session_context()
@pytest.fixture
async def async_session_maker(async_engine):
"""Create an async session maker bound to the async engine."""
async_session_maker = async_sessionmaker(
bind=async_engine,
class_=AsyncSession,
expire_on_commit=False,
def mock_session():
"""Create mock async session."""
return create_mock_session()
@pytest.fixture
def mock_session_maker(mock_session):
"""Create mock async session maker."""
return create_mock_session_maker(mock_session)
@pytest.fixture
def auth_token_store(mock_session_maker):
"""Create AuthTokenStore instance with mocked session maker."""
return AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
# Create all tables
async with async_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
return async_session_maker
class TestIsTokenExpired:
"""Tests for _is_token_expired method."""
def test_both_tokens_valid(self):
def test_both_tokens_valid(self, auth_token_store):
"""Test when both tokens are valid (not expired)."""
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=ProviderType.GITHUB,
)
current_time = int(time.time())
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
refresh_expires = current_time + 1000
access_expired, refresh_expired = store._is_token_expired(
access_expired, refresh_expired = auth_token_store._is_token_expired(
access_expires, refresh_expires
)
assert access_expired is False
assert refresh_expired is False
def test_access_token_expired(self):
def test_access_token_expired(self, auth_token_store):
"""Test when access token is expired but within buffer."""
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=ProviderType.GITHUB,
)
current_time = int(time.time())
# Access token expires within buffer period
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER - 100
refresh_expires = current_time + 10000
access_expired, refresh_expired = store._is_token_expired(
access_expired, refresh_expired = auth_token_store._is_token_expired(
access_expires, refresh_expires
)
assert access_expired is True
assert refresh_expired is False
def test_refresh_token_expired(self):
def test_refresh_token_expired(self, auth_token_store):
"""Test when refresh token is expired."""
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=ProviderType.GITHUB,
)
current_time = int(time.time())
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
refresh_expires = current_time - 100 # Already expired
access_expired, refresh_expired = store._is_token_expired(
access_expired, refresh_expired = auth_token_store._is_token_expired(
access_expires, refresh_expires
)
assert access_expired is False
assert refresh_expired is True
def test_both_tokens_expired(self):
def test_both_tokens_expired(self, auth_token_store):
"""Test when both tokens are expired."""
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=ProviderType.GITHUB,
)
current_time = int(time.time())
access_expires = current_time - 100
refresh_expires = current_time - 100
access_expired, refresh_expired = store._is_token_expired(
access_expired, refresh_expired = auth_token_store._is_token_expired(
access_expires, refresh_expires
)
assert access_expired is True
assert refresh_expired is True
def test_zero_expiration_treated_as_never_expires(self):
def test_zero_expiration_treated_as_never_expires(self, auth_token_store):
"""Test that 0 expiration time is treated as never expires."""
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=ProviderType.GITHUB,
)
access_expired, refresh_expired = store._is_token_expired(0, 0)
access_expired, refresh_expired = auth_token_store._is_token_expired(0, 0)
assert access_expired is False
assert refresh_expired is False
@@ -131,188 +131,427 @@ class TestLoadTokensFastPath:
"""Tests for load_tokens fast path (no lock needed)."""
@pytest.mark.asyncio
async def test_fast_path_token_not_found(self, async_session_maker):
async def test_fast_path_token_not_found(
self, auth_token_store, mock_session_maker, mock_session
):
"""Test fast path returns None when no token record exists."""
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
result = await store.load_tokens()
result = await auth_token_store.load_tokens()
assert result is None
assert result is None
@pytest.mark.asyncio
async def test_fast_path_valid_token_no_refresh_needed(self, async_session_maker):
async def test_fast_path_valid_token_no_refresh_needed(
self, auth_token_store, mock_session_maker, mock_session
):
"""Test fast path returns tokens when they are still valid."""
current_time = int(time.time())
mock_token = MagicMock()
mock_token.access_token = 'valid-access-token'
mock_token.refresh_token = 'valid-refresh-token'
mock_token.access_token_expires_at = (
current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
)
mock_token.refresh_token_expires_at = current_time + 10000
# First, store a valid token in the database
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = mock_token
mock_session.execute = AsyncMock(return_value=mock_result)
await store.store_tokens(
access_token='valid-access-token',
refresh_token='valid-refresh-token',
access_token_expires_at=current_time
+ ACCESS_TOKEN_EXPIRY_BUFFER
+ 1000,
refresh_token_expires_at=current_time + 10000,
)
result = await auth_token_store.load_tokens()
# Now load tokens - should return valid tokens without refresh
result = await store.load_tokens()
assert result is not None
assert result['access_token'] == 'valid-access-token'
assert result['refresh_token'] == 'valid-refresh-token'
assert result is not None
assert result['access_token'] == 'valid-access-token'
assert result['refresh_token'] == 'valid-refresh-token'
@pytest.mark.asyncio
async def test_fast_path_no_refresh_callback_provided(self, async_session_maker):
async def test_fast_path_no_refresh_callback_provided(
self, auth_token_store, mock_session_maker, mock_session
):
"""Test fast path returns existing tokens when no refresh callback is provided."""
current_time = int(time.time())
mock_token = MagicMock()
mock_token.access_token = 'expired-access-token'
mock_token.refresh_token = 'valid-refresh-token'
# Expired access token
mock_token.access_token_expires_at = current_time - 100
mock_token.refresh_token_expires_at = current_time + 10000
# Store expired access token
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = mock_token
mock_session.execute = AsyncMock(return_value=mock_result)
await store.store_tokens(
access_token='expired-access-token',
refresh_token='valid-refresh-token',
access_token_expires_at=current_time - 100, # Expired
refresh_token_expires_at=current_time + 10000,
)
result = await auth_token_store.load_tokens(check_expiration_and_refresh=None)
# Load without refresh callback - should still return tokens
result = await store.load_tokens(check_expiration_and_refresh=None)
assert result is not None
assert result['access_token'] == 'expired-access-token'
assert result is not None
assert result['access_token'] == 'expired-access-token'
class TestLoadTokensSlowPath:
"""Tests for load_tokens slow path (lock required for refresh).
"""Tests for load_tokens slow path (lock required for refresh)."""
Note: These tests require PostgreSQL's lock_timeout feature which is not
available in SQLite. The slow path tests are skipped when using SQLite.
"""
@pytest.mark.skip(reason='SQLite does not support PostgreSQL lock_timeout syntax')
@pytest.mark.asyncio
async def test_slow_path_successful_refresh(self, async_session_maker):
async def test_slow_path_successful_refresh(self):
"""Test slow path successfully refreshes expired tokens."""
pass
@pytest.mark.skip(reason='SQLite does not support PostgreSQL lock_timeout syntax')
@pytest.mark.asyncio
async def test_refresh_callback_returns_none(self, async_session_maker):
"""Test behavior when refresh callback returns None (no refresh performed)."""
pass
@pytest.mark.asyncio
async def test_slow_path_double_check_avoids_refresh(self, async_session_maker):
"""Test double-check pattern avoids unnecessary refresh."""
current_time = int(time.time())
mock_session = create_mock_session()
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
# First call (fast path) - returns expired token
# Second call (slow path) - returns same token for update
expired_token = MagicMock()
expired_token.id = 1
expired_token.access_token = 'expired-access-token'
expired_token.refresh_token = 'valid-refresh-token'
expired_token.access_token_expires_at = current_time - 100 # Expired
expired_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = expired_token
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh(
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
) -> Dict[str, str | int]:
return {
'access_token': 'new-access-token',
'refresh_token': 'new-refresh-token',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
assert result is not None
assert result['access_token'] == 'new-access-token'
assert result['refresh_token'] == 'new-refresh-token'
@pytest.mark.asyncio
async def test_slow_path_double_check_avoids_refresh(self):
"""Test double-check locking: token was refreshed by another request."""
current_time = int(time.time())
mock_session = create_mock_session()
# Simulate scenario:
# 1. Fast path sees expired token
# 2. While waiting for lock, another request refreshes
# 3. Slow path sees fresh token, skips refresh
call_count = [0]
def create_token():
call_count[0] += 1
token = MagicMock()
token.id = 1
token.access_token = 'fresh-access-token'
token.refresh_token = 'fresh-refresh-token'
if call_count[0] == 1:
# First call (fast path) - expired
token.access_token_expires_at = current_time - 100
else:
# Second call (slow path) - already refreshed
token.access_token_expires_at = (
current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
)
token.refresh_token_expires_at = current_time + 86400
return token
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.side_effect = (
lambda: create_token()
)
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
refresh_called = [False]
async def mock_refresh(
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
) -> Dict[str, str | int]:
refresh_called[0] = True
return {
'access_token': 'should-not-be-used',
'refresh_token': 'should-not-be-used',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
# The refresh callback should not be called because double-check
# found the token was already refreshed
assert result is not None
assert result['access_token'] == 'fresh-access-token'
@pytest.mark.asyncio
async def test_slow_path_token_not_found_after_lock(self):
"""Test slow path returns None if token record disappears after lock."""
current_time = int(time.time())
mock_session = create_mock_session()
# First call (fast path) - token exists but expired
# Second call (slow path with lock) - token no longer exists
call_count = [0]
def get_token():
call_count[0] += 1
if call_count[0] == 1:
token = MagicMock()
token.access_token_expires_at = current_time - 100 # Expired
token.refresh_token_expires_at = current_time + 10000
return token
return None
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.side_effect = get_token
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh(*args) -> Dict[str, str | int]:
return {
'access_token': 'new-token',
'refresh_token': 'new-refresh',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
assert result is None
class TestLoadTokensLockTimeout:
"""Tests for lock timeout handling."""
@pytest.mark.asyncio
async def test_lock_timeout_raises_token_refresh_error(self):
"""Test that lock timeout raises TokenRefreshError."""
current_time = int(time.time())
mock_session = create_mock_session()
# First call (fast path) - returns expired token
expired_token = MagicMock()
expired_token.access_token_expires_at = current_time - 100
expired_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = expired_token
# First execute for fast path succeeds
# Second execute (for slow path) raises OperationalError
call_count = [0]
async def execute_side_effect(*args, **kwargs):
call_count[0] += 1
if call_count[0] <= 1:
return mock_result
# Simulate lock timeout
raise OperationalError(
'canceling statement due to lock timeout', None, None
)
# Store a token that will be valid when second check happens
await store.store_tokens(
access_token='original-access-token',
refresh_token='valid-refresh-token',
access_token_expires_at=current_time
+ ACCESS_TOKEN_EXPIRY_BUFFER
+ 1000,
refresh_token_expires_at=current_time + 10000,
)
mock_session.execute = execute_side_effect
# Load with refresh callback - should NOT refresh since token is valid
result = await store.load_tokens()
mock_session_maker = create_mock_session_maker(mock_session)
assert result is not None
assert result['access_token'] == 'original-access-token'
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh(*args) -> Dict[str, str | int]:
return {
'access_token': 'new-token',
'refresh_token': 'new-refresh',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
with pytest.raises(TokenRefreshError) as exc_info:
await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
assert 'lock timeout' in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_lock_timeout_preserves_original_exception(self):
"""Test that TokenRefreshError preserves the original OperationalError."""
current_time = int(time.time())
mock_session = create_mock_session()
expired_token = MagicMock()
expired_token.access_token_expires_at = current_time - 100
expired_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = expired_token
original_error = OperationalError(
'canceling statement due to lock timeout', None, None
)
call_count = [0]
async def execute_side_effect(*args, **kwargs):
call_count[0] += 1
if call_count[0] <= 1:
return mock_result
raise original_error
mock_session.execute = execute_side_effect
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh(*args) -> Dict[str, str | int]:
return {
'access_token': 'new-token',
'refresh_token': 'new-refresh',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
with pytest.raises(TokenRefreshError) as exc_info:
await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
# Verify the original exception is chained
assert exc_info.value.__cause__ is original_error
class TestLoadTokensRefreshCallbackBehavior:
"""Tests for refresh callback return values."""
@pytest.mark.asyncio
async def test_refresh_callback_returns_none(self):
"""Test behavior when refresh callback returns None (no refresh performed)."""
current_time = int(time.time())
mock_session = create_mock_session()
expired_token = MagicMock()
expired_token.id = 1
expired_token.access_token = 'old-access-token'
expired_token.refresh_token = 'old-refresh-token'
expired_token.access_token_expires_at = current_time - 100 # Expired
expired_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = expired_token
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh_returns_none(
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
) -> Dict[str, str | int] | None:
return None
result = await auth_store.load_tokens(
check_expiration_and_refresh=mock_refresh_returns_none
)
# Should return the old tokens when refresh returns None
assert result is not None
assert result['access_token'] == 'old-access-token'
assert result['refresh_token'] == 'old-refresh-token'
class TestStoreTokens:
"""Tests for store_tokens method."""
@pytest.mark.asyncio
async def test_store_tokens_creates_new_record(self, async_session_maker):
async def test_store_tokens_creates_new_record(self):
"""Test storing tokens when no existing record."""
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
mock_session = create_mock_session()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.add = MagicMock()
mock_session.commit = AsyncMock()
await store.store_tokens(
access_token='new-access-token',
refresh_token='new-refresh-token',
access_token_expires_at=1234567890,
refresh_token_expires_at=1234657890,
)
mock_session_maker = create_mock_session_maker(mock_session)
# Verify the token was stored
async with async_session_maker() as session:
result = await session.execute(
select(AuthTokens).where(
AuthTokens.keycloak_user_id == 'test-user-123',
AuthTokens.identity_provider == ProviderType.GITHUB.value,
)
)
token_record = result.scalars().first()
assert token_record is not None
assert token_record.access_token == 'new-access-token'
assert token_record.refresh_token == 'new-refresh-token'
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
await auth_store.store_tokens(
access_token='new-access-token',
refresh_token='new-refresh-token',
access_token_expires_at=1234567890,
refresh_token_expires_at=1234657890,
)
mock_session.add.assert_called_once()
@pytest.mark.asyncio
async def test_store_tokens_updates_existing_record(self, async_session_maker):
async def test_store_tokens_updates_existing_record(self):
"""Test storing tokens updates existing record."""
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
mock_session = create_mock_session()
existing_token = MagicMock()
existing_token.access_token = 'old-access'
# First, create a token record
await store.store_tokens(
access_token='old-access-token',
refresh_token='old-refresh-token',
access_token_expires_at=1234567890,
refresh_token_expires_at=1234657890,
)
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = existing_token
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
# Now update it
await store.store_tokens(
access_token='new-access-token',
refresh_token='new-refresh-token',
access_token_expires_at=1234567891,
refresh_token_expires_at=1234657891,
)
mock_session_maker = create_mock_session_maker(mock_session)
# Verify the token was updated
async with async_session_maker() as session:
result = await session.execute(
select(AuthTokens).where(
AuthTokens.keycloak_user_id == 'test-user-123',
AuthTokens.identity_provider == ProviderType.GITHUB.value,
)
)
token_record = result.scalars().first()
assert token_record is not None
assert token_record.access_token == 'new-access-token'
assert token_record.refresh_token == 'new-refresh-token'
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
await auth_store.store_tokens(
access_token='new-access-token',
refresh_token='new-refresh-token',
access_token_expires_at=1234567890,
refresh_token_expires_at=1234657890,
)
assert existing_token.access_token == 'new-access-token'
assert existing_token.refresh_token == 'new-refresh-token'
class TestIsAccessTokenValid:
@@ -320,93 +559,80 @@ class TestIsAccessTokenValid:
@pytest.mark.asyncio
async def test_is_access_token_valid_returns_false_when_no_tokens(
self, async_session_maker
self, auth_token_store, mock_session_maker, mock_session
):
"""Test returns False when no tokens found."""
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
result = await store.is_access_token_valid()
result = await auth_token_store.is_access_token_valid()
assert result is False
assert result is False
@pytest.mark.asyncio
async def test_is_access_token_valid_returns_true_for_valid_token(
self, async_session_maker
self, auth_token_store, mock_session_maker, mock_session
):
"""Test returns True when token is valid."""
current_time = int(time.time())
mock_token = MagicMock()
mock_token.access_token = 'valid-access'
mock_token.refresh_token = 'valid-refresh'
mock_token.access_token_expires_at = current_time + 1000
mock_token.refresh_token_expires_at = current_time + 10000
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = mock_token
mock_session.execute = AsyncMock(return_value=mock_result)
await store.store_tokens(
access_token='valid-access',
refresh_token='valid-refresh',
access_token_expires_at=current_time + 1000,
refresh_token_expires_at=current_time + 10000,
)
result = await auth_token_store.is_access_token_valid()
result = await store.is_access_token_valid()
assert result is True
assert result is True
@pytest.mark.asyncio
async def test_is_access_token_valid_returns_false_for_expired_token(
self, async_session_maker
self, auth_token_store, mock_session_maker, mock_session
):
"""Test returns False when token is expired."""
current_time = int(time.time())
mock_token = MagicMock()
mock_token.access_token = 'expired-access'
mock_token.refresh_token = 'valid-refresh'
mock_token.access_token_expires_at = current_time - 100 # Expired
mock_token.refresh_token_expires_at = current_time + 10000
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
)
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = mock_token
mock_session.execute = AsyncMock(return_value=mock_result)
await store.store_tokens(
access_token='expired-access',
refresh_token='valid-refresh',
access_token_expires_at=current_time - 100, # Expired
refresh_token_expires_at=current_time + 10000,
)
result = await auth_token_store.is_access_token_valid()
result = await store.is_access_token_valid()
assert result is False
assert result is False
class TestGetInstance:
"""Tests for get_instance class method."""
@pytest.mark.asyncio
async def test_get_instance_creates_auth_token_store(self, async_session_maker):
async def test_get_instance_creates_auth_token_store(self):
"""Test get_instance creates an AuthTokenStore with correct params."""
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
with patch('storage.auth_token_store.a_session_maker') as mock_a_session_maker:
store = await AuthTokenStore.get_instance(
keycloak_user_id='user-123', idp=ProviderType.GITHUB
)
assert store.keycloak_user_id == 'user-123'
assert store.idp == ProviderType.GITHUB
assert store.a_session_maker is mock_a_session_maker
class TestIdentityProviderValue:
"""Tests for identity_provider_value property."""
def test_identity_provider_value_returns_idp_value(self):
def test_identity_provider_value_returns_idp_value(self, auth_token_store):
"""Test that identity_provider_value returns the enum value."""
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=ProviderType.GITHUB,
)
assert store.identity_provider_value == ProviderType.GITHUB.value
assert auth_token_store.identity_provider_value == ProviderType.GITHUB.value
def test_identity_provider_value_for_different_providers(self):
"""Test identity_provider_value for different providers."""
@@ -418,6 +644,7 @@ class TestIdentityProviderValue:
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=provider,
a_session_maker=MagicMock(),
)
assert store.identity_provider_value == provider.value

View File

@@ -1,17 +1,33 @@
"""Unit tests for DeviceCodeStore."""
from unittest.mock import patch
from unittest.mock import MagicMock
import pytest
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from storage.device_code import DeviceCode
from storage.device_code_store import DeviceCodeStore
@pytest.fixture
def device_code_store():
def mock_session():
"""Mock database session."""
session = MagicMock()
return session
@pytest.fixture
def mock_session_maker(mock_session):
"""Mock session maker."""
session_maker = MagicMock()
session_maker.return_value.__enter__.return_value = mock_session
session_maker.return_value.__exit__.return_value = None
return session_maker
@pytest.fixture
def device_code_store(mock_session_maker):
"""Create DeviceCodeStore instance."""
return DeviceCodeStore()
return DeviceCodeStore(mock_session_maker)
class TestDeviceCodeStore:
@@ -33,257 +49,145 @@ class TestDeviceCodeStore:
assert len(code) == 128
assert code.isalnum()
@pytest.mark.asyncio
async def test_create_device_code_success(
self, device_code_store, async_session_maker
):
def test_create_device_code_success(self, device_code_store, mock_session):
"""Test successful device code creation."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
result = await device_code_store.create_device_code(expires_in=600)
# Mock successful creation (no IntegrityError)
mock_device_code = MagicMock(spec=DeviceCode)
mock_device_code.device_code = 'test-device-code-123'
mock_device_code.user_code = 'TESTCODE'
# Mock the session to return our mock device code after refresh
def mock_refresh(obj):
obj.device_code = mock_device_code.device_code
obj.user_code = mock_device_code.user_code
mock_session.refresh.side_effect = mock_refresh
result = device_code_store.create_device_code(expires_in=600)
assert isinstance(result, DeviceCode)
assert len(result.device_code) == 128
assert len(result.user_code) == 8
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once()
mock_session.expunge.assert_called_once()
# Verify the DeviceCode was created in the database
async with async_session_maker() as session:
result_db = await session.execute(
select(DeviceCode).filter(DeviceCode.device_code == result.device_code)
)
device_code = result_db.scalars().first()
assert device_code is not None
assert device_code.user_code == result.user_code
@pytest.mark.asyncio
async def test_create_device_code_with_retries(
self, device_code_store, async_session_maker
def test_create_device_code_with_retries(
self, device_code_store, mock_session_maker
):
"""Test device code creation with constraint violation retries."""
# First create a device code to cause a collision
with patch('storage.device_code_store.a_session_maker', async_session_maker):
first_code = await device_code_store.create_device_code(expires_in=600)
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
# Patch generate methods to return the same codes on first attempt,
# then different codes on second attempt
call_count = {'user': 0, 'device': 0}
original_generate_user_code = device_code_store.generate_user_code
original_generate_device_code = device_code_store.generate_device_code
# First attempt fails with IntegrityError, second succeeds
mock_session.commit.side_effect = [IntegrityError('', '', ''), None]
def mock_generate_user_code():
call_count['user'] += 1
if call_count['user'] == 1:
return first_code.user_code # Collision
return original_generate_user_code()
mock_device_code = MagicMock(spec=DeviceCode)
mock_device_code.device_code = 'test-device-code-456'
mock_device_code.user_code = 'TESTCD2'
def mock_generate_device_code():
call_count['device'] += 1
if call_count['device'] == 1:
return first_code.device_code # Collision
return original_generate_device_code()
def mock_refresh(obj):
obj.device_code = mock_device_code.device_code
obj.user_code = mock_device_code.user_code
device_code_store.generate_user_code = mock_generate_user_code
device_code_store.generate_device_code = mock_generate_device_code
mock_session.refresh.side_effect = mock_refresh
with patch('storage.device_code_store.a_session_maker', async_session_maker):
result = await device_code_store.create_device_code(expires_in=600)
store = DeviceCodeStore(mock_session_maker)
result = store.create_device_code(expires_in=600)
assert isinstance(result, DeviceCode)
assert result.device_code != first_code.device_code # Should be different
assert call_count['user'] == 2 # Two attempts
assert mock_session.add.call_count == 2 # Two attempts
assert mock_session.commit.call_count == 2 # Two attempts
@pytest.mark.asyncio
async def test_create_device_code_max_attempts_exceeded(
self, device_code_store, async_session_maker
def test_create_device_code_max_attempts_exceeded(
self, device_code_store, mock_session_maker
):
"""Test device code creation failure after max attempts."""
# First create a device code
with patch('storage.device_code_store.a_session_maker', async_session_maker):
first_code = await device_code_store.create_device_code(expires_in=600)
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
# Always return the same codes to cause repeated collisions
device_code_store.generate_user_code = lambda: first_code.user_code
device_code_store.generate_device_code = lambda: first_code.device_code
# All attempts fail with IntegrityError
mock_session.commit.side_effect = IntegrityError('', '', '')
with patch('storage.device_code_store.a_session_maker', async_session_maker):
with pytest.raises(
RuntimeError,
match='Failed to generate unique device codes after 3 attempts',
):
await device_code_store.create_device_code(
expires_in=600, max_attempts=3
)
store = DeviceCodeStore(mock_session_maker)
@pytest.mark.asyncio
async def test_get_by_device_code(self, device_code_store, async_session_maker):
"""Test getting device code by device code."""
# Create a device code first
with patch('storage.device_code_store.a_session_maker', async_session_maker):
created = await device_code_store.create_device_code(expires_in=600)
result = await device_code_store.get_by_device_code(created.device_code)
with pytest.raises(
RuntimeError,
match='Failed to generate unique device codes after 3 attempts',
):
store.create_device_code(expires_in=600, max_attempts=3)
assert result is not None
assert result.device_code == created.device_code
assert result.user_code == created.user_code
@pytest.mark.asyncio
async def test_get_by_device_code_not_found(
self, device_code_store, async_session_maker
@pytest.mark.parametrize(
'lookup_method,lookup_field',
[
('get_by_device_code', 'device_code'),
('get_by_user_code', 'user_code'),
],
)
def test_lookup_methods(
self, device_code_store, mock_session, lookup_method, lookup_field
):
"""Test getting non-existent device code."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
result = await device_code_store.get_by_device_code('non-existent-code')
"""Test device code lookup methods."""
test_code = 'test-code-123'
mock_device_code = MagicMock()
mock_session.query.return_value.filter_by.return_value.first.return_value = (
mock_device_code
)
assert result is None
result = getattr(device_code_store, lookup_method)(test_code)
@pytest.mark.asyncio
async def test_get_by_user_code(self, device_code_store, async_session_maker):
"""Test getting device code by user code."""
# Create a device code first
with patch('storage.device_code_store.a_session_maker', async_session_maker):
created = await device_code_store.create_device_code(expires_in=600)
result = await device_code_store.get_by_user_code(created.user_code)
assert result == mock_device_code
mock_session.query.assert_called_once_with(DeviceCode)
mock_session.query.return_value.filter_by.assert_called_once_with(
**{lookup_field: test_code}
)
assert result is not None
assert result.device_code == created.device_code
assert result.user_code == created.user_code
@pytest.mark.asyncio
async def test_get_by_user_code_not_found(
self, device_code_store, async_session_maker
@pytest.mark.parametrize(
'device_exists,is_pending,expected_result',
[
(True, True, True), # Success case
(False, True, False), # Device not found
(True, False, False), # Device not pending
],
)
def test_authorize_device_code(
self,
device_code_store,
mock_session,
device_exists,
is_pending,
expected_result,
):
"""Test getting non-existent user code."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
result = await device_code_store.get_by_user_code('NOTFOUND')
assert result is None
@pytest.mark.asyncio
async def test_authorize_device_code_success(
self, device_code_store, async_session_maker
):
"""Test successful device code authorization."""
"""Test device code authorization."""
user_code = 'ABC12345'
user_id = 'test-user-123'
with patch('storage.device_code_store.a_session_maker', async_session_maker):
created = await device_code_store.create_device_code(expires_in=600)
result = await device_code_store.authorize_device_code(
created.user_code, user_id
)
if device_exists:
mock_device = MagicMock()
mock_device.is_pending.return_value = is_pending
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_device
else:
mock_session.query.return_value.filter_by.return_value.first.return_value = None
result = device_code_store.authorize_device_code(user_code, user_id)
assert result == expected_result
if expected_result:
mock_device.authorize.assert_called_once_with(user_id)
mock_session.commit.assert_called_once()
def test_deny_device_code(self, device_code_store, mock_session):
"""Test device code denial."""
user_code = 'ABC12345'
mock_device = MagicMock()
mock_device.is_pending.return_value = True
mock_session.query.return_value.filter_by.return_value.first.return_value = (
mock_device
)
result = device_code_store.deny_device_code(user_code)
assert result is True
# Verify the device code was authorized in the database
async with async_session_maker() as session:
result_db = await session.execute(
select(DeviceCode).filter(DeviceCode.user_code == created.user_code)
)
device_code = result_db.scalars().first()
assert device_code.status == 'authorized'
assert device_code.keycloak_user_id == user_id
@pytest.mark.asyncio
async def test_authorize_device_code_not_found(
self, device_code_store, async_session_maker
):
"""Test authorizing non-existent device code."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
result = await device_code_store.authorize_device_code(
'NOTFOUND', 'user-123'
)
assert result is False
@pytest.mark.asyncio
async def test_authorize_device_code_not_pending(
self, device_code_store, async_session_maker
):
"""Test authorizing already authorized device code."""
user_id = 'test-user-123'
with patch('storage.device_code_store.a_session_maker', async_session_maker):
created = await device_code_store.create_device_code(expires_in=600)
# First authorization
await device_code_store.authorize_device_code(created.user_code, user_id)
# Second authorization should fail
result = await device_code_store.authorize_device_code(
created.user_code, 'another-user'
)
assert result is False
@pytest.mark.asyncio
async def test_deny_device_code_success(
self, device_code_store, async_session_maker
):
"""Test successful device code denial."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
created = await device_code_store.create_device_code(expires_in=600)
result = await device_code_store.deny_device_code(created.user_code)
assert result is True
# Verify the device code was denied in the database
async with async_session_maker() as session:
result_db = await session.execute(
select(DeviceCode).filter(DeviceCode.user_code == created.user_code)
)
device_code = result_db.scalars().first()
assert device_code.status == 'denied'
@pytest.mark.asyncio
async def test_deny_device_code_not_found(
self, device_code_store, async_session_maker
):
"""Test denying non-existent device code."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
result = await device_code_store.deny_device_code('NOTFOUND')
assert result is False
@pytest.mark.asyncio
async def test_deny_device_code_not_pending(
self, device_code_store, async_session_maker
):
"""Test denying already denied device code."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
created = await device_code_store.create_device_code(expires_in=600)
# First denial
await device_code_store.deny_device_code(created.user_code)
# Second denial should fail
result = await device_code_store.deny_device_code(created.user_code)
assert result is False
@pytest.mark.asyncio
async def test_update_poll_time_success(
self, device_code_store, async_session_maker
):
"""Test updating poll time."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
created = await device_code_store.create_device_code(expires_in=600)
original_interval = created.current_interval
result = await device_code_store.update_poll_time(
created.device_code, increase_interval=True
)
assert result is True
# Verify the poll time was updated
async with async_session_maker() as session:
result_db = await session.execute(
select(DeviceCode).filter(DeviceCode.device_code == created.device_code)
)
device_code = result_db.scalars().first()
assert device_code.current_interval > original_interval
@pytest.mark.asyncio
async def test_update_poll_time_not_found(
self, device_code_store, async_session_maker
):
"""Test updating poll time for non-existent device code."""
with patch('storage.device_code_store.a_session_maker', async_session_maker):
result = await device_code_store.update_poll_time(
'non-existent-code', increase_interval=False
)
assert result is False
mock_device.deny.assert_called_once()
mock_session.commit.assert_called_once()

View File

@@ -9,35 +9,16 @@ from storage.base import Base
from storage.gitlab_webhook import GitlabWebhook
from storage.gitlab_webhook_store import GitlabWebhookStore
# Use module-scoped engine to share database across fixtures
_test_engine = None
@pytest.fixture(scope='function')
def event_loop():
"""Create an instance of the default event loop for each test case."""
import asyncio
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope='function')
async def async_engine(event_loop):
"""Create an async SQLite engine for testing.
This fixture creates an in-memory SQLite database and ensures
all tables are created before tests run.
"""
global _test_engine
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
_test_engine = engine
# Create all tables
async with engine.begin() as conn:
@@ -48,7 +29,7 @@ async def async_engine(event_loop):
await engine.dispose()
@pytest.fixture(scope='function')
@pytest.fixture
async def async_session_maker(async_engine):
"""Create an async session maker for testing."""
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
@@ -56,21 +37,8 @@ async def async_session_maker(async_engine):
@pytest.fixture
async def webhook_store(async_session_maker):
"""Create a GitlabWebhookStore instance for testing.
This fixture injects the test's async_session_maker to ensure
the store uses the same in-memory database as the test fixtures.
"""
# Import here to avoid circular imports
store = GitlabWebhookStore()
# Inject the test session maker - this needs to replace the module-level import
import storage.gitlab_webhook_store as store_module
store_module.a_session_maker = async_session_maker
return store
"""Create a GitlabWebhookStore instance for testing."""
return GitlabWebhookStore(a_session_maker=async_session_maker)
@pytest.fixture
@@ -134,7 +102,7 @@ class TestGetWebhookByResourceOnly:
@pytest.mark.asyncio
async def test_get_project_webhook_by_resource_only(
self, webhook_store, sample_webhooks
self, webhook_store, async_session_maker, sample_webhooks
):
"""Test getting a project webhook by resource ID without user_id filter."""
# Arrange

View File

@@ -1,232 +0,0 @@
"""
Tests for JiraIntegrationStore async methods.
The store uses async database sessions (a_session_maker) for all operations,
which is critical for avoiding asyncpg event loop issues when called from
FastAPI async endpoints.
"""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, Mock, patch
import pytest
from storage.jira_integration_store import JiraIntegrationStore
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
@pytest.fixture
def store():
"""Create a JiraIntegrationStore instance."""
return JiraIntegrationStore()
@pytest.fixture
def create_mock_async_session():
"""Factory to create properly mocked async session context manager."""
def _create(query_result=None, all_results=None):
mock_session = Mock()
mock_result = Mock()
if all_results is not None:
mock_result.scalars.return_value.all.return_value = all_results
else:
mock_result.scalars.return_value.first.return_value = query_result
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.add = Mock()
mock_session.commit = AsyncMock()
mock_session.refresh = AsyncMock()
@asynccontextmanager
async def mock_context_manager():
yield mock_session
return mock_context_manager, mock_session
return _create
class TestJiraIntegrationStoreAsyncMethods:
"""Tests verifying JiraIntegrationStore methods use async sessions correctly."""
@pytest.mark.asyncio
async def test_get_workspace_by_id_returns_workspace(
self, store, create_mock_async_session
):
"""Test get_workspace_by_id returns workspace when found."""
# Arrange
mock_workspace = Mock(spec=JiraWorkspace)
mock_workspace.id = 1
mock_workspace.name = 'test-workspace'
mock_context_manager, mock_session = create_mock_async_session(mock_workspace)
# Act
with patch(
'storage.jira_integration_store.a_session_maker', mock_context_manager
):
result = await store.get_workspace_by_id(1)
# Assert
assert result == mock_workspace
mock_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_workspace_by_id_returns_none_when_not_found(
self, store, create_mock_async_session
):
"""Test get_workspace_by_id returns None when workspace not found."""
# Arrange
mock_context_manager, mock_session = create_mock_async_session(None)
# Act
with patch(
'storage.jira_integration_store.a_session_maker', mock_context_manager
):
result = await store.get_workspace_by_id(999)
# Assert
assert result is None
@pytest.mark.asyncio
async def test_get_workspace_by_name_normalizes_to_lowercase(
self, store, create_mock_async_session
):
"""Test get_workspace_by_name converts name to lowercase for query."""
# Arrange
mock_workspace = Mock(spec=JiraWorkspace)
mock_workspace.name = 'test-workspace'
mock_context_manager, mock_session = create_mock_async_session(mock_workspace)
# Act
with patch(
'storage.jira_integration_store.a_session_maker', mock_context_manager
):
result = await store.get_workspace_by_name('TEST-WORKSPACE')
# Assert
assert result == mock_workspace
# Verify the query was executed (filter includes lowercase conversion)
mock_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_get_active_user_filters_by_status(
self, store, create_mock_async_session
):
"""Test get_active_user only returns users with active status."""
# Arrange
mock_user = Mock(spec=JiraUser)
mock_user.jira_user_id = 'jira-123'
mock_user.jira_workspace_id = 1
mock_user.status = 'active'
mock_context_manager, mock_session = create_mock_async_session(mock_user)
# Act
with patch(
'storage.jira_integration_store.a_session_maker', mock_context_manager
):
result = await store.get_active_user('jira-123', 1)
# Assert
assert result == mock_user
mock_session.execute.assert_called_once()
@pytest.mark.asyncio
async def test_create_workspace_adds_and_commits(
self, store, create_mock_async_session
):
"""Test create_workspace properly adds, commits, and refreshes."""
# Arrange
mock_context_manager, mock_session = create_mock_async_session(None)
# Act
with patch(
'storage.jira_integration_store.a_session_maker', mock_context_manager
):
await store.create_workspace(
name='TEST-WORKSPACE',
jira_cloud_id='cloud-123',
admin_user_id='admin-user',
encrypted_webhook_secret='encrypted-secret',
svc_acc_email='svc@test.com',
encrypted_svc_acc_api_key='encrypted-key',
status='active',
)
# Assert
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once()
# Verify workspace was created with lowercase name
added_workspace = mock_session.add.call_args[0][0]
assert added_workspace.name == 'test-workspace'
@pytest.mark.asyncio
async def test_update_user_integration_status_raises_if_not_found(
self, store, create_mock_async_session
):
"""Test update_user_integration_status raises ValueError if user not found."""
# Arrange
mock_context_manager, mock_session = create_mock_async_session(None)
# Act & Assert
with patch(
'storage.jira_integration_store.a_session_maker', mock_context_manager
):
with pytest.raises(ValueError) as exc_info:
await store.update_user_integration_status('unknown-user', 'inactive')
assert 'Jira user not found' in str(exc_info.value)
@pytest.mark.asyncio
async def test_deactivate_workspace_deactivates_all_users(
self, store, create_mock_async_session
):
"""Test deactivate_workspace sets all users and workspace to inactive."""
# Arrange
mock_user1 = Mock(spec=JiraUser)
mock_user1.status = 'active'
mock_user2 = Mock(spec=JiraUser)
mock_user2.status = 'active'
mock_workspace = Mock(spec=JiraWorkspace)
mock_workspace.status = 'active'
mock_session = Mock()
# First execute returns users, second returns workspace
call_count = [0]
def execute_side_effect(*args, **kwargs):
result = Mock()
if call_count[0] == 0:
result.scalars.return_value.all.return_value = [mock_user1, mock_user2]
else:
result.scalars.return_value.first.return_value = mock_workspace
call_count[0] += 1
return result
mock_session.execute = AsyncMock(side_effect=execute_side_effect)
mock_session.add = Mock()
mock_session.commit = AsyncMock()
@asynccontextmanager
async def mock_context_manager():
yield mock_session
# Act
with patch(
'storage.jira_integration_store.a_session_maker', mock_context_manager
):
await store.deactivate_workspace(1)
# Assert
assert mock_user1.status == 'inactive'
assert mock_user2.status == 'inactive'
assert mock_workspace.status == 'inactive'
mock_session.commit.assert_called_once()

View File

@@ -5,15 +5,21 @@ Tests the async database operations for organization app settings.
"""
import uuid
from unittest.mock import patch
import pytest
from server.routes.org_models import OrgAppSettingsUpdate
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.base import Base
from storage.org import Org
from storage.org_app_settings_store import OrgAppSettingsStore
from storage.user import User
# Mock the database module before importing
with patch('storage.database.engine', create=True), patch(
'storage.database.a_engine', create=True
):
from server.routes.org_models import OrgAppSettingsUpdate
from storage.base import Base
from storage.org import Org
from storage.org_app_settings_store import OrgAppSettingsStore
from storage.user import User
@pytest.fixture

View File

@@ -8,13 +8,18 @@ import uuid
from unittest.mock import AsyncMock, patch
import pytest
from server.routes.org_models import OrgLLMSettingsUpdate
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.base import Base
from storage.org import Org
from storage.org_llm_settings_store import OrgLLMSettingsStore
from storage.user import User
# Mock the database module before importing
with patch('storage.database.engine', create=True), patch(
'storage.database.a_engine', create=True
):
from server.routes.org_models import OrgLLMSettingsUpdate
from storage.base import Base
from storage.org import Org
from storage.org_llm_settings_store import OrgLLMSettingsStore
from storage.user import User
@pytest.fixture

View File

@@ -5,15 +5,21 @@ Tests the async database operations for user app settings.
"""
import uuid
from unittest.mock import patch
import pytest
from server.routes.user_app_settings_models import UserAppSettingsUpdate
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.base import Base
from storage.org import Org
from storage.user import User
from storage.user_app_settings_store import UserAppSettingsStore
# Mock the database module before importing
with patch('storage.database.engine', create=True), patch(
'storage.database.a_engine', create=True
):
from server.routes.user_app_settings_models import UserAppSettingsUpdate
from storage.base import Base
from storage.org import Org
from storage.user import User
from storage.user_app_settings_store import UserAppSettingsStore
@pytest.fixture

View File

@@ -1,49 +1,40 @@
import uuid
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import select
from storage.api_key import ApiKey
from storage.api_key_store import ApiKeyStore
@pytest.fixture
def mock_session():
session = MagicMock()
return session
@pytest.fixture
def mock_session_maker(mock_session):
session_maker = MagicMock()
session_maker.return_value.__enter__.return_value = mock_session
session_maker.return_value.__exit__.return_value = None
return session_maker
@pytest.fixture
def mock_user():
"""Mock user with org_id."""
user = MagicMock()
user.current_org_id = uuid.uuid4()
user.current_org_id = 'test-org-123'
return user
@pytest.fixture
def api_key_store():
return ApiKeyStore()
def api_key_store(mock_session_maker):
return ApiKeyStore(mock_session_maker)
@pytest.fixture
def mock_litellm_api():
api_key_patch = patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test_key')
api_url_patch = patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'
)
team_id_patch = patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test_team')
client_patch = patch('httpx.AsyncClient')
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
mock_response = AsyncMock()
mock_response.is_success = True
mock_response.json = MagicMock(return_value={'key': 'test_api_key'})
mock_client.return_value.__aenter__.return_value.post.return_value = (
mock_response
)
mock_client.return_value.__aenter__.return_value.get.return_value = (
mock_response
)
mock_client.return_value.__aenter__.return_value.patch.return_value = (
mock_response
)
yield mock_client
def run_sync(func, *args, **kwargs):
"""Helper to execute sync functions directly (mocks call_sync_from_async)."""
return func(*args, **kwargs)
def test_generate_api_key(api_key_store):
@@ -56,445 +47,294 @@ def test_generate_api_key(api_key_store):
@pytest.mark.asyncio
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
async def test_create_api_key(
mock_get_user, api_key_store, async_session_maker, mock_user
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
):
"""Test creating an API key."""
# Setup
user_id = str(uuid.uuid4())
user_id = 'test-user-123'
name = 'Test Key'
mock_get_user.return_value = mock_user
api_key_store.generate_api_key = MagicMock(return_value='test-api-key')
# Patch a_session_maker in the api_key_store module to use the test's async session maker
with patch('storage.api_key_store.a_session_maker', async_session_maker):
# Execute
result = await api_key_store.create_api_key(user_id, name)
# Verify
assert result.startswith('sk-oh-')
mock_get_user.assert_called_once_with(user_id)
# Verify the ApiKey was created in the database using async session
async with async_session_maker() as session:
result_db = await session.execute(
select(ApiKey).filter(ApiKey.user_id == user_id)
)
api_key = result_db.scalars().first()
assert api_key is not None
assert api_key.name == name
assert api_key.org_id == mock_user.current_org_id
@pytest.mark.asyncio
async def test_validate_api_key_valid(api_key_store, async_session_maker):
"""Test validating a valid API key."""
# Setup - create an API key in the database
user_id = str(uuid.uuid4())
org_id = uuid.uuid4()
api_key_value = 'test-api-key'
async with async_session_maker() as session:
key_record = ApiKey(
key=api_key_value,
user_id=user_id,
org_id=org_id,
name='Test Key',
expires_at=None,
)
session.add(key_record)
await session.commit()
# Execute - patch a_session_maker to use test's async session maker
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.validate_api_key(api_key_value)
# Verify
assert result == user_id
@pytest.mark.asyncio
async def test_validate_api_key_expired(api_key_store, async_session_maker):
"""Test validating an expired API key."""
# Setup - create an expired API key in the database
user_id = str(uuid.uuid4())
org_id = uuid.uuid4()
api_key_value = 'test-expired-key'
async with async_session_maker() as session:
key_record = ApiKey(
key=api_key_value,
user_id=user_id,
org_id=org_id,
name='Test Key',
expires_at=datetime.now(UTC) - timedelta(days=1),
)
session.add(key_record)
await session.commit()
# Execute - patch a_session_maker to use test's async session maker
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.validate_api_key(api_key_value)
# Verify
assert result is None
@pytest.mark.asyncio
async def test_validate_api_key_expired_timezone_naive(
api_key_store, async_session_maker
):
"""Test validating an expired API key with timezone-naive datetime from database."""
# Setup - create an expired API key with timezone-naive datetime
user_id = str(uuid.uuid4())
org_id = uuid.uuid4()
api_key_value = 'test-expired-naive-key'
async with async_session_maker() as session:
key_record = ApiKey(
key=api_key_value,
user_id=user_id,
org_id=org_id,
name='Test Key',
# Timezone-naive datetime (database stores this)
expires_at=datetime.now() - timedelta(days=1),
)
session.add(key_record)
await session.commit()
# Execute - patch a_session_maker to use test's async session maker
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.validate_api_key(api_key_value)
# Verify
assert result is None
@pytest.mark.asyncio
async def test_validate_api_key_valid_timezone_naive(
api_key_store, async_session_maker
):
"""Test validating a valid API key with timezone-naive datetime from database."""
# Setup - create a valid API key with timezone-naive datetime (future date)
user_id = str(uuid.uuid4())
org_id = uuid.uuid4()
api_key_value = 'test-valid-naive-key'
async with async_session_maker() as session:
key_record = ApiKey(
key=api_key_value,
user_id=user_id,
org_id=org_id,
name='Test Key',
# Timezone-naive datetime in the future
expires_at=datetime.now() + timedelta(days=1),
)
session.add(key_record)
await session.commit()
# Execute - patch a_session_maker to use test's async session maker
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.validate_api_key(api_key_value)
# Verify
assert result == user_id
@pytest.mark.asyncio
async def test_validate_api_key_not_found(api_key_store, async_session_maker):
"""Test validating a non-existent API key."""
# Execute
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.validate_api_key('non-existent-key')
result = await api_key_store.create_api_key(user_id, name)
# Verify
assert result == 'test-api-key'
mock_get_user.assert_called_once_with(user_id)
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
api_key_store.generate_api_key.assert_called_once()
# Verify the ApiKey was created with the correct org_id
added_api_key = mock_session.add.call_args[0][0]
assert added_api_key.org_id == mock_user.current_org_id
def test_validate_api_key_valid(api_key_store, mock_session):
"""Test validating a valid API key."""
# Setup
api_key = 'test-api-key'
user_id = 'test-user-123'
mock_key_record = MagicMock()
mock_key_record.user_id = user_id
mock_key_record.expires_at = None
mock_key_record.id = 1
mock_session.query.return_value.filter.return_value.first.return_value = (
mock_key_record
)
# Execute
result = api_key_store.validate_api_key(api_key)
# Verify
assert result == user_id
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_validate_api_key_expired(api_key_store, mock_session):
"""Test validating an expired API key."""
# Setup
api_key = 'test-api-key'
mock_key_record = MagicMock()
mock_key_record.expires_at = datetime.now(UTC) - timedelta(days=1)
mock_key_record.id = 1
mock_session.query.return_value.filter.return_value.first.return_value = (
mock_key_record
)
# Execute
result = api_key_store.validate_api_key(api_key)
# Verify
assert result is None
mock_session.execute.assert_not_called()
mock_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_delete_api_key(api_key_store, async_session_maker):
def test_validate_api_key_expired_timezone_naive(api_key_store, mock_session):
"""Test validating an expired API key with timezone-naive datetime from database."""
# Setup
api_key = 'test-api-key'
mock_key_record = MagicMock()
# Simulate timezone-naive datetime as returned from database
mock_key_record.expires_at = datetime.now() - timedelta(days=1) # No UTC timezone
mock_key_record.id = 1
mock_session.query.return_value.filter.return_value.first.return_value = (
mock_key_record
)
# Execute
result = api_key_store.validate_api_key(api_key)
# Verify
assert result is None
mock_session.execute.assert_not_called()
mock_session.commit.assert_not_called()
def test_validate_api_key_valid_timezone_naive(api_key_store, mock_session):
"""Test validating a valid API key with timezone-naive datetime from database."""
# Setup
api_key = 'test-api-key'
user_id = 'test-user-123'
mock_key_record = MagicMock()
mock_key_record.user_id = user_id
# Simulate timezone-naive datetime as returned from database (future date)
mock_key_record.expires_at = datetime.now() + timedelta(days=1) # No UTC timezone
mock_key_record.id = 1
mock_session.query.return_value.filter.return_value.first.return_value = (
mock_key_record
)
# Execute
result = api_key_store.validate_api_key(api_key)
# Verify
assert result == user_id
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_validate_api_key_not_found(api_key_store, mock_session):
"""Test validating a non-existent API key."""
# Setup
api_key = 'test-api-key'
query_result = mock_session.query.return_value.filter.return_value
query_result.first.return_value = None
# Execute
result = api_key_store.validate_api_key(api_key)
# Verify
assert result is None
mock_session.execute.assert_not_called()
mock_session.commit.assert_not_called()
def test_delete_api_key(api_key_store, mock_session):
"""Test deleting an API key."""
# Setup - create an API key in the database
user_id = str(uuid.uuid4())
org_id = uuid.uuid4()
api_key_value = 'test-delete-key'
# Setup
api_key = 'test-api-key'
mock_key_record = MagicMock()
mock_session.query.return_value.filter.return_value.first.return_value = (
mock_key_record
)
async with async_session_maker() as session:
key_record = ApiKey(
key=api_key_value,
user_id=user_id,
org_id=org_id,
name='Test Key',
)
session.add(key_record)
await session.commit()
# Execute - patch a_session_maker to use test's async session maker
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.delete_api_key(api_key_value)
# Execute
result = api_key_store.delete_api_key(api_key)
# Verify
assert result is True
# Verify it was deleted from the database
async with async_session_maker() as session:
result_db = await session.execute(
select(ApiKey).filter(ApiKey.key == api_key_value)
)
api_key = result_db.scalars().first()
assert api_key is None
mock_session.delete.assert_called_once_with(mock_key_record)
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_delete_api_key_not_found(api_key_store, async_session_maker):
def test_delete_api_key_not_found(api_key_store, mock_session):
"""Test deleting a non-existent API key."""
# Setup
api_key = 'test-api-key'
query_result = mock_session.query.return_value.filter.return_value
query_result.first.return_value = None
# Execute
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.delete_api_key('non-existent-key')
result = api_key_store.delete_api_key(api_key)
# Verify
assert result is False
mock_session.delete.assert_not_called()
mock_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_delete_api_key_by_id(api_key_store, async_session_maker):
def test_delete_api_key_by_id(api_key_store, mock_session):
"""Test deleting an API key by ID."""
# Setup - create an API key in the database
user_id = str(uuid.uuid4())
org_id = uuid.uuid4()
# Setup
key_id = 123
mock_key_record = MagicMock()
mock_session.query.return_value.filter.return_value.first.return_value = (
mock_key_record
)
async with async_session_maker() as session:
key_record = ApiKey(
key='test-delete-by-id-key',
user_id=user_id,
org_id=org_id,
name='Test Key',
)
session.add(key_record)
await session.commit()
key_id = key_record.id
# Execute - patch a_session_maker to use test's async session maker
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.delete_api_key_by_id(key_id)
# Execute
result = api_key_store.delete_api_key_by_id(key_id)
# Verify
assert result is True
# Verify it was deleted from the database
async with async_session_maker() as session:
result_db = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
api_key = result_db.scalars().first()
assert api_key is None
mock_session.delete.assert_called_once_with(mock_key_record)
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
async def test_list_api_keys(
mock_get_user, api_key_store, async_session_maker, mock_user
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
):
"""Test listing API keys for a user."""
# Setup
user_id = str(uuid.uuid4())
user_id = 'test-user-123'
mock_get_user.return_value = mock_user
now = datetime.now(UTC)
mock_key1 = MagicMock()
mock_key1.id = 1
mock_key1.name = 'Key 1'
mock_key1.created_at = now
mock_key1.last_used_at = now
mock_key1.expires_at = now + timedelta(days=30)
# Create API keys in the database
async with async_session_maker() as session:
key1 = ApiKey(
key='test-key-1',
user_id=user_id,
org_id=mock_user.current_org_id,
name='Key 1',
created_at=now,
last_used_at=now,
expires_at=now + timedelta(days=30),
)
key2 = ApiKey(
key='test-key-2',
user_id=user_id,
org_id=mock_user.current_org_id,
name='Key 2',
created_at=now,
last_used_at=None,
expires_at=None,
)
# Add an MCP key that should be filtered out
mcp_key = ApiKey(
key='test-mcp-key',
user_id=user_id,
org_id=mock_user.current_org_id,
name='MCP_API_KEY',
created_at=now,
)
session.add_all([key1, key2, mcp_key])
await session.commit()
mock_key2 = MagicMock()
mock_key2.id = 2
mock_key2.name = 'Key 2'
mock_key2.created_at = now
mock_key2.last_used_at = None
mock_key2.expires_at = None
# Execute - patch a_session_maker to use test's async session maker
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.list_api_keys(user_id)
# Mock the chained query calls for filtering by user_id and org_id
mock_query = mock_session.query.return_value
mock_filter_user = mock_query.filter.return_value
mock_filter_org = mock_filter_user.filter.return_value
mock_filter_org.all.return_value = [mock_key1, mock_key2]
# Execute
result = await api_key_store.list_api_keys(user_id)
# Verify
mock_get_user.assert_called_once_with(user_id)
assert len(result) == 2
assert result[0].id == 1
assert result[0].name == 'Key 1'
assert result[0].created_at == now
assert result[0].last_used_at == now
assert result[0].expires_at == now + timedelta(days=30)
assert result[1].id == 2
assert result[1].name == 'Key 2'
assert result[1].created_at == now
assert result[1].last_used_at is None
assert result[1].expires_at is None
@pytest.mark.asyncio
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
async def test_retrieve_mcp_api_key(
mock_get_user, api_key_store, async_session_maker, mock_user
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
):
"""Test retrieving MCP API key for a user."""
# Setup
user_id = str(uuid.uuid4())
user_id = 'test-user-123'
mock_get_user.return_value = mock_user
now = datetime.now(UTC)
# Create API keys in the database
async with async_session_maker() as session:
other_key = ApiKey(
key='test-other-key',
user_id=user_id,
org_id=mock_user.current_org_id,
name='Other Key',
created_at=now,
)
mcp_key = ApiKey(
key='test-mcp-key',
user_id=user_id,
org_id=mock_user.current_org_id,
name='MCP_API_KEY',
created_at=now,
)
session.add_all([other_key, mcp_key])
await session.commit()
mock_mcp_key = MagicMock()
mock_mcp_key.name = 'MCP_API_KEY'
mock_mcp_key.key = 'mcp-test-key'
# Execute - patch a_session_maker to use test's async session maker
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.retrieve_mcp_api_key(user_id)
mock_other_key = MagicMock()
mock_other_key.name = 'Other Key'
mock_other_key.key = 'other-test-key'
# Mock the chained query calls for filtering by user_id and org_id
mock_query = mock_session.query.return_value
mock_filter_user = mock_query.filter.return_value
mock_filter_org = mock_filter_user.filter.return_value
mock_filter_org.all.return_value = [mock_other_key, mock_mcp_key]
# Execute
result = await api_key_store.retrieve_mcp_api_key(user_id)
# Verify
mock_get_user.assert_called_once_with(user_id)
assert result == 'test-mcp-key'
assert result == 'mcp-test-key'
@pytest.mark.asyncio
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
async def test_retrieve_mcp_api_key_not_found(
mock_get_user, api_key_store, async_session_maker, mock_user
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
):
"""Test retrieving MCP API key when none exists."""
# Setup
user_id = str(uuid.uuid4())
user_id = 'test-user-123'
mock_get_user.return_value = mock_user
now = datetime.now(UTC)
# Create only non-MCP keys in the database
async with async_session_maker() as session:
other_key = ApiKey(
key='test-other-key',
user_id=user_id,
org_id=mock_user.current_org_id,
name='Other Key',
created_at=now,
)
session.add(other_key)
await session.commit()
mock_other_key = MagicMock()
mock_other_key.name = 'Other Key'
mock_other_key.key = 'other-test-key'
# Execute - patch a_session_maker to use test's async session maker
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.retrieve_mcp_api_key(user_id)
# Mock the chained query calls for filtering by user_id and org_id
mock_query = mock_session.query.return_value
mock_filter_user = mock_query.filter.return_value
mock_filter_org = mock_filter_user.filter.return_value
mock_filter_org.all.return_value = [mock_other_key]
# Execute
result = await api_key_store.retrieve_mcp_api_key(user_id)
# Verify
mock_get_user.assert_called_once_with(user_id)
assert result is None
@pytest.mark.asyncio
async def test_retrieve_api_key_by_name(api_key_store, async_session_maker):
"""Test retrieving an API key by name."""
# Setup
user_id = str(uuid.uuid4())
org_id = uuid.uuid4()
key_name = 'Test Key'
key_value = 'test-key-by-name'
async with async_session_maker() as session:
key_record = ApiKey(
key=key_value,
user_id=user_id,
org_id=org_id,
name=key_name,
)
session.add(key_record)
await session.commit()
# Execute - patch a_session_maker to use test's async session maker
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.retrieve_api_key_by_name(user_id, key_name)
# Verify
assert result == key_value
@pytest.mark.asyncio
async def test_retrieve_api_key_by_name_not_found(api_key_store, async_session_maker):
"""Test retrieving an API key by name that doesn't exist."""
# Execute
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.retrieve_api_key_by_name(
'non-existent-user', 'Non Existent Key'
)
# Verify
assert result is None
@pytest.mark.asyncio
async def test_delete_api_key_by_name(api_key_store, async_session_maker):
"""Test deleting an API key by name."""
# Setup
user_id = str(uuid.uuid4())
org_id = uuid.uuid4()
key_name = 'Test Key to Delete'
key_value = 'test-delete-by-name'
async with async_session_maker() as session:
key_record = ApiKey(
key=key_value,
user_id=user_id,
org_id=org_id,
name=key_name,
)
session.add(key_record)
await session.commit()
# Execute - patch a_session_maker to use test's async session maker
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.delete_api_key_by_name(user_id, key_name)
# Verify
assert result is True
# Verify it was deleted from the database
async with async_session_maker() as session:
result_db = await session.execute(
select(ApiKey).filter(ApiKey.key == key_value)
)
api_key = result_db.scalars().first()
assert api_key is None
@pytest.mark.asyncio
async def test_delete_api_key_by_name_not_found(api_key_store, async_session_maker):
"""Test deleting an API key by name that doesn't exist."""
# Execute
with patch('storage.api_key_store.a_session_maker', async_session_maker):
result = await api_key_store.delete_api_key_by_name(
'non-existent-user', 'Non Existent Key'
)
# Verify
assert result is False

View File

@@ -595,7 +595,7 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
mock_user_store.backfill_user_email = AsyncMock()
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=True)
mock_domain_blocker.is_domain_blocked.return_value = True
# Act
result = await keycloak_callback(
@@ -621,7 +621,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.a_session_maker') as mock_session_maker,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_session = MagicMock()
@@ -660,7 +660,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
mock_user_store.backfill_user_email = AsyncMock()
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
mock_domain_blocker.is_domain_blocked.return_value = False
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -686,7 +686,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.a_session_maker') as mock_session_maker,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_session = MagicMock()
@@ -725,7 +725,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
mock_user_store.backfill_user_email = AsyncMock()
mock_domain_blocker.is_active.return_value = False
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
mock_domain_blocker.is_domain_blocked.return_value = False
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
@@ -749,7 +749,7 @@ async def test_keycloak_callback_missing_email(mock_request):
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.a_session_maker') as mock_session_maker,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_session = MagicMock()
@@ -898,7 +898,7 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.a_session_maker') as mock_session_maker,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.UserStore') as mock_user_store,
):
# Arrange
@@ -959,7 +959,7 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.a_session_maker') as mock_session_maker,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.UserStore') as mock_user_store,
):
# Arrange
@@ -1022,7 +1022,7 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.a_session_maker') as mock_session_maker,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.UserStore') as mock_user_store,
):
# Arrange
@@ -1174,7 +1174,7 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
patch('server.routes.auth.a_session_maker') as mock_session_maker,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
@@ -1221,7 +1221,7 @@ class TestKeycloakCallbackRecaptcha:
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
mock_domain_blocker.is_domain_blocked.return_value = False
# Patch the module-level recaptcha_service instance
mock_recaptcha_service.create_assessment.return_value = (
@@ -1284,7 +1284,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
mock_domain_blocker.is_domain_blocked.return_value = False
# Patch the module-level recaptcha_service instance
mock_recaptcha_service.create_assessment.return_value = (
@@ -1325,7 +1325,7 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.a_session_maker') as mock_session_maker,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
patch('server.routes.email.verify_email', new_callable=AsyncMock),
@@ -1371,7 +1371,7 @@ class TestKeycloakCallbackRecaptcha:
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
mock_domain_blocker.is_domain_blocked.return_value = False
# Patch the module-level recaptcha_service instance
mock_recaptcha_service.create_assessment.return_value = (
@@ -1414,7 +1414,7 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.a_session_maker') as mock_session_maker,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
patch('server.routes.email.verify_email', new_callable=AsyncMock),
@@ -1460,7 +1460,7 @@ class TestKeycloakCallbackRecaptcha:
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
mock_domain_blocker.is_domain_blocked.return_value = False
# Patch the module-level recaptcha_service instance
mock_recaptcha_service.create_assessment.return_value = (
@@ -1500,7 +1500,7 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.a_session_maker') as mock_session_maker,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
patch('server.routes.email.verify_email', new_callable=AsyncMock),
@@ -1546,7 +1546,7 @@ class TestKeycloakCallbackRecaptcha:
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
mock_domain_blocker.is_domain_blocked.return_value = False
# Patch the module-level recaptcha_service instance
mock_recaptcha_service.create_assessment.return_value = (
@@ -1585,7 +1585,7 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.a_session_maker') as mock_session_maker,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
patch('server.routes.email.verify_email', new_callable=AsyncMock),
@@ -1631,7 +1631,7 @@ class TestKeycloakCallbackRecaptcha:
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
mock_domain_blocker.is_domain_blocked.return_value = False
# Patch the module-level recaptcha_service instance
mock_recaptcha_service.create_assessment.return_value = (
@@ -1666,7 +1666,7 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
patch('server.routes.auth.RECAPTCHA_SITE_KEY', ''),
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.a_session_maker') as mock_session_maker,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
@@ -1713,7 +1713,7 @@ class TestKeycloakCallbackRecaptcha:
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
mock_domain_blocker.is_domain_blocked.return_value = False
# Act
await keycloak_callback(
@@ -1734,7 +1734,7 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.a_session_maker') as mock_session_maker,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
@@ -1781,7 +1781,7 @@ class TestKeycloakCallbackRecaptcha:
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
mock_domain_blocker.is_domain_blocked.return_value = False
# Act
await keycloak_callback(code='test_code', state=state, request=mock_request)
@@ -1808,7 +1808,7 @@ class TestKeycloakCallbackRecaptcha:
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.a_session_maker') as mock_session_maker,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.set_response_cookie'),
patch('server.routes.auth.posthog'),
@@ -1855,7 +1855,7 @@ class TestKeycloakCallbackRecaptcha:
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
mock_domain_blocker.is_domain_blocked.return_value = False
mock_recaptcha_service.create_assessment.side_effect = Exception(
'Service error'
@@ -1924,7 +1924,7 @@ class TestKeycloakCallbackRecaptcha:
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
mock_domain_blocker.is_domain_blocked.return_value = False
# Patch the module-level recaptcha_service instance
mock_recaptcha_service.create_assessment.return_value = (

View File

@@ -6,7 +6,6 @@ import pytest
import stripe
from fastapi import HTTPException, Request, status
from httpx import Response
from server.constants import ORG_SETTINGS_VERSION
from server.routes import billing
from server.routes.billing import (
CreateBillingSessionResponse,
@@ -19,11 +18,22 @@ from server.routes.billing import (
has_payment_method,
success_callback,
)
from sqlalchemy import select
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from starlette.datastructures import URL
from storage.billing_session import BillingSession
from storage.org import Org
from storage.user import User
from storage.stripe_customer import Base as StripeCustomerBase
@pytest.fixture
def engine():
engine = create_engine('sqlite:///:memory:')
StripeCustomerBase.metadata.create_all(engine)
return engine
@pytest.fixture
def session_maker(engine):
return sessionmaker(bind=engine)
@pytest.fixture
@@ -66,38 +76,6 @@ def mock_subscription_request():
return request
@pytest.fixture
async def test_org(async_session_maker):
"""Create a test org in the database."""
org_id = uuid.uuid4()
async with async_session_maker() as session:
org = Org(
id=org_id,
name=f'test-org-{org_id}',
org_version=ORG_SETTINGS_VERSION,
enable_default_condenser=True,
enable_proactive_conversation_starters=True,
)
session.add(org)
await session.commit()
return org
@pytest.fixture
async def test_user(async_session_maker, test_org):
"""Create a test user in the database linked to test_org."""
user_id = uuid.uuid4()
async with async_session_maker() as session:
user = User(
id=user_id,
current_org_id=test_org.id,
user_consents_to_analytics=True,
)
session.add(user)
await session.commit()
return user
@pytest.mark.asyncio
async def test_get_credits_lite_llm_error():
with (
@@ -155,14 +133,17 @@ async def test_get_credits_success():
@pytest.mark.asyncio
async def test_create_checkout_session_stripe_error(
async_session_maker, mock_checkout_request, test_org
session_maker, mock_checkout_request
):
"""Test handling of Stripe API errors."""
mock_customer = stripe.Customer(
id='mock-customer', metadata={'user_id': 'mock-user'}
)
mock_customer_create = AsyncMock(return_value=mock_customer)
mock_org = MagicMock()
mock_org.id = uuid.uuid4()
mock_org.contact_email = 'testy@tester.com'
with (
pytest.raises(Exception, match='Stripe API Error'),
patch('stripe.Customer.create_async', mock_customer_create),
@@ -173,13 +154,10 @@ async def test_create_checkout_session_stripe_error(
'stripe.checkout.Session.create_async',
AsyncMock(side_effect=Exception('Stripe API Error')),
),
patch('server.routes.billing.a_session_maker', async_session_maker),
patch('integrations.stripe_service.a_session_maker', async_session_maker),
patch('storage.database.a_session_maker', async_session_maker),
patch('storage.org_store.a_session_maker', async_session_maker),
patch('integrations.stripe_service.session_maker', session_maker),
patch(
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
return_value=test_org,
return_value=mock_org,
),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
@@ -193,27 +171,44 @@ async def test_create_checkout_session_stripe_error(
@pytest.mark.asyncio
async def test_create_checkout_session_success(
async_session_maker, mock_checkout_request, test_org
):
async def test_create_checkout_session_success(session_maker, mock_checkout_request):
"""Test successful creation of checkout session."""
mock_session = MagicMock()
mock_session.url = 'https://checkout.stripe.com/test-session'
mock_session.id = 'test_session_id_checkout'
mock_session.id = 'test_session_id'
mock_create = AsyncMock(return_value=mock_session)
mock_create.return_value = mock_session
mock_customer_info = {'customer_id': 'mock-customer', 'org_id': test_org.id}
mock_customer = stripe.Customer(
id='mock-customer', metadata={'user_id': 'mock-user'}
)
mock_customer_create = AsyncMock(return_value=mock_customer)
mock_org = MagicMock()
mock_org_id = uuid.uuid4()
mock_org.id = mock_org_id
mock_org.contact_email = 'testy@tester.com'
with (
patch('stripe.checkout.Session.create_async', mock_create),
patch('server.routes.billing.a_session_maker', async_session_maker),
patch('integrations.stripe_service.a_session_maker', async_session_maker),
patch('stripe.Customer.create_async', mock_customer_create),
patch(
'integrations.stripe_service.find_or_create_customer_by_user_id',
AsyncMock(return_value=mock_customer_info),
'stripe.Customer.search_async', AsyncMock(return_value=MagicMock(data=[]))
),
patch('stripe.checkout.Session.create_async', mock_create),
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('integrations.stripe_service.session_maker', session_maker),
patch(
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
return_value=mock_org,
),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
),
patch('server.routes.billing.validate_billing_enabled'),
):
mock_db_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_db_session
result = await create_checkout_session(
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
)
@@ -245,102 +240,74 @@ async def test_create_checkout_session_success(
cancel_url='https://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
)
# Verify database record was created
async with async_session_maker() as session:
result_db = await session.execute(
select(BillingSession).where(
BillingSession.id == 'test_session_id_checkout'
)
)
billing_session = result_db.scalar_one_or_none()
assert billing_session is not None
assert billing_session.user_id == 'mock_user'
assert billing_session.org_id == test_org.id
assert billing_session.status == 'in_progress'
assert float(billing_session.price) == 25.0
# Verify database session creation
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_success_callback_session_not_found(async_session_maker):
async def test_success_callback_session_not_found():
"""Test success callback when billing session is not found."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
with (
patch('server.routes.billing.a_session_maker', async_session_maker),
patch('stripe.checkout.Session.retrieve'),
):
with patch('server.routes.billing.session_maker') as mock_session_maker:
mock_db_session = MagicMock()
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = None
mock_session_maker.return_value.__enter__.return_value = mock_db_session
with pytest.raises(HTTPException) as exc_info:
await success_callback('nonexistent_session_id', mock_request)
await success_callback('test_session_id', mock_request)
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
mock_db_session.merge.assert_not_called()
mock_db_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_success_callback_stripe_incomplete(
async_session_maker, test_org, test_user
):
async def test_success_callback_stripe_incomplete():
"""Test success callback when Stripe session is not complete."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
session_id = 'test_incomplete_session'
async with async_session_maker() as session:
billing_session = BillingSession(
id=session_id,
user_id=str(test_user.id),
org_id=test_org.id,
status='in_progress',
price=25,
price_code='NA',
)
session.add(billing_session)
await session.commit()
mock_billing_session = MagicMock()
mock_billing_session.status = 'in_progress'
mock_billing_session.user_id = 'mock_user'
with (
patch('server.routes.billing.a_session_maker', async_session_maker),
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
):
mock_db_session = MagicMock()
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
mock_session_maker.return_value.__enter__.return_value = mock_db_session
mock_stripe_retrieve.return_value = MagicMock(status='pending')
with pytest.raises(HTTPException) as exc_info:
await success_callback(session_id, mock_request)
await success_callback('test_session_id', mock_request)
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
# Verify no database update occurred
async with async_session_maker() as session:
result = await session.execute(
select(BillingSession).where(BillingSession.id == session_id)
)
billing_session = result.scalar_one_or_none()
assert billing_session.status == 'in_progress'
mock_db_session.merge.assert_not_called()
mock_db_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_success_callback_success(async_session_maker, test_org, test_user):
async def test_success_callback_success():
"""Test successful payment completion and credit update."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
session_id = 'test_success_session'
async with async_session_maker() as session:
billing_session = BillingSession(
id=session_id,
user_id=str(test_user.id),
org_id=test_org.id,
status='in_progress',
price=25,
price_code='NA',
)
session.add(billing_session)
await session.commit()
mock_billing_session = MagicMock()
mock_billing_session.status = 'in_progress'
mock_billing_session.user_id = 'mock_user'
mock_org = MagicMock()
with (
patch('server.routes.billing.a_session_maker', async_session_maker),
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
patch(
'storage.user_store.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id=test_org.id),
return_value=MagicMock(current_org_id='mock_org_id'),
),
patch(
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
@@ -353,11 +320,25 @@ async def test_success_callback_success(async_session_maker, test_org, test_user
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
) as mock_update_budget,
):
mock_db_session = MagicMock()
# First query: BillingSession (query().filter().filter().first())
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
# Second query: Org (query().filter().first()) - use side_effect for different return chains
mock_query_chain_billing = MagicMock()
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
mock_query_chain_org = MagicMock()
mock_query_chain_org.filter.return_value.first.return_value = mock_org
mock_db_session.query.side_effect = [
mock_query_chain_billing,
mock_query_chain_org,
]
mock_session_maker.return_value.__enter__.return_value = mock_db_session
mock_stripe_retrieve.return_value = MagicMock(
status='complete', amount_subtotal=2500, customer='mock_customer_id'
)
) # $25.00 in cents
response = await success_callback(session_id, mock_request)
response = await success_callback('test_session_id', mock_request)
assert response.status_code == 302
assert (
@@ -365,80 +346,64 @@ async def test_success_callback_success(async_session_maker, test_org, test_user
== 'https://test.com/settings/billing?checkout=success'
)
# Verify LiteLLM API calls
mock_update_budget.assert_called_once_with(
str(test_org.id),
'mock_org_id',
125.0, # 100 + 25.00
)
# Verify database updates
async with async_session_maker() as session:
result = await session.execute(
select(BillingSession).where(BillingSession.id == session_id)
)
billing_session = result.scalar_one_or_none()
assert billing_session.status == 'completed'
assert float(billing_session.price) == 25.0
# Verify BYOR export is enabled for the org (updated in same session)
assert mock_org.byor_export_enabled is True
# Verify org byor_export_enabled was set
org_result = await session.execute(select(Org).where(Org.id == test_org.id))
org = org_result.scalar_one_or_none()
assert org.byor_export_enabled is True
# Verify database updates
assert mock_billing_session.status == 'completed'
assert mock_billing_session.price == 25.0
mock_db_session.merge.assert_called_once()
mock_db_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_success_callback_lite_llm_error(
async_session_maker, test_org, test_user
):
async def test_success_callback_lite_llm_error():
"""Test handling of LiteLLM API errors during success callback."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
session_id = 'test_litellm_error_session'
async with async_session_maker() as session:
billing_session = BillingSession(
id=session_id,
user_id=str(test_user.id),
org_id=test_org.id,
status='in_progress',
price=25,
price_code='NA',
)
session.add(billing_session)
await session.commit()
mock_billing_session = MagicMock()
mock_billing_session.status = 'in_progress'
mock_billing_session.user_id = 'mock_user'
with (
patch('server.routes.billing.a_session_maker', async_session_maker),
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
patch(
'storage.user_store.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id=test_org.id),
return_value=MagicMock(current_org_id='mock_org_id'),
),
patch(
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
side_effect=Exception('LiteLLM API Error'),
),
):
mock_db_session = MagicMock()
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
mock_session_maker.return_value.__enter__.return_value = mock_db_session
mock_stripe_retrieve.return_value = MagicMock(
status='complete', amount_subtotal=2500
)
with pytest.raises(Exception, match='LiteLLM API Error'):
await success_callback(session_id, mock_request)
await success_callback('test_session_id', mock_request)
# Verify no database updates occurred (transaction rolled back)
async with async_session_maker() as session:
result = await session.execute(
select(BillingSession).where(BillingSession.id == session_id)
)
billing_session = result.scalar_one_or_none()
assert billing_session.status == 'in_progress'
# Verify no database updates occurred
assert mock_billing_session.status == 'in_progress'
mock_db_session.merge.assert_not_called()
mock_db_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_success_callback_lite_llm_update_budget_error_rollback(
async_session_maker, test_org, test_user
):
async def test_success_callback_lite_llm_update_budget_error_rollback():
"""Test that database changes are not committed when update_team_and_users_budget fails.
This test verifies that if LiteLlmManager.update_team_and_users_budget raises an exception,
@@ -447,26 +412,19 @@ async def test_success_callback_lite_llm_update_budget_error_rollback(
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
session_id = 'test_budget_rollback_session'
async with async_session_maker() as session:
billing_session = BillingSession(
id=session_id,
user_id=str(test_user.id),
org_id=test_org.id,
status='in_progress',
price=10,
price_code='NA',
)
session.add(billing_session)
await session.commit()
mock_billing_session = MagicMock()
mock_billing_session.status = 'in_progress'
mock_billing_session.user_id = 'mock_user'
mock_org = MagicMock()
with (
patch('server.routes.billing.a_session_maker', async_session_maker),
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
patch(
'storage.user_store.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id=test_org.id),
return_value=MagicMock(current_org_id='mock_org_id'),
),
patch(
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
@@ -480,60 +438,70 @@ async def test_success_callback_lite_llm_update_budget_error_rollback(
side_effect=Exception('LiteLLM API Error'),
),
):
mock_db_session = MagicMock()
mock_query_chain_billing = MagicMock()
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
mock_query_chain_org = MagicMock()
mock_query_chain_org.filter.return_value.first.return_value = mock_org
mock_db_session.query.side_effect = [
mock_query_chain_billing,
mock_query_chain_org,
]
mock_session_maker.return_value.__enter__.return_value = mock_db_session
mock_stripe_retrieve.return_value = MagicMock(
status='complete',
amount_subtotal=1000,
amount_subtotal=1000, # $10
customer='mock_customer_id',
)
with pytest.raises(Exception, match='LiteLLM API Error'):
await success_callback(session_id, mock_request)
await success_callback('test_session_id', mock_request)
# Verify no database commit occurred - the transaction should roll back
async with async_session_maker() as session:
result = await session.execute(
select(BillingSession).where(BillingSession.id == session_id)
)
billing_session = result.scalar_one_or_none()
assert billing_session.status == 'in_progress'
# Verify no database commit occurred - the transaction should roll back
assert mock_billing_session.status == 'in_progress'
mock_db_session.merge.assert_not_called()
mock_db_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_cancel_callback_session_not_found(async_session_maker):
async def test_cancel_callback_session_not_found():
"""Test cancel callback when billing session is not found."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
with patch('server.routes.billing.a_session_maker', async_session_maker):
response = await cancel_callback('nonexistent_session_id', mock_request)
with patch('server.routes.billing.session_maker') as mock_session_maker:
mock_db_session = MagicMock()
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = None
mock_session_maker.return_value.__enter__.return_value = mock_db_session
response = await cancel_callback('test_session_id', mock_request)
assert response.status_code == 302
assert (
response.headers['location']
== 'https://test.com/settings/billing?checkout=cancel'
)
# Verify no database updates occurred
mock_db_session.merge.assert_not_called()
mock_db_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_cancel_callback_success(async_session_maker, test_org, test_user):
async def test_cancel_callback_success():
"""Test successful cancellation of billing session."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
session_id = 'test_cancel_session'
async with async_session_maker() as session:
billing_session = BillingSession(
id=session_id,
user_id=str(test_user.id),
org_id=test_org.id,
status='in_progress',
price=25,
price_code='NA',
)
session.add(billing_session)
await session.commit()
mock_billing_session = MagicMock()
mock_billing_session.status = 'in_progress'
with patch('server.routes.billing.a_session_maker', async_session_maker):
response = await cancel_callback(session_id, mock_request)
with patch('server.routes.billing.session_maker') as mock_session_maker:
mock_db_session = MagicMock()
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
mock_session_maker.return_value.__enter__.return_value = mock_db_session
response = await cancel_callback('test_session_id', mock_request)
assert response.status_code == 302
assert (
@@ -541,18 +509,16 @@ async def test_cancel_callback_success(async_session_maker, test_org, test_user)
== 'https://test.com/settings/billing?checkout=cancel'
)
# Verify database update
async with async_session_maker() as session:
result = await session.execute(
select(BillingSession).where(BillingSession.id == session_id)
)
billing_session = result.scalar_one_or_none()
assert billing_session.status == 'cancelled'
# Verify database updates
assert mock_billing_session.status == 'cancelled'
mock_db_session.merge.assert_called_once()
mock_db_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_has_payment_method_with_payment_method():
"""Test has_payment_method returns True when user has a payment method."""
mock_has_payment_method = AsyncMock(return_value=True)
with patch(
'server.routes.billing.stripe_service.has_payment_method_by_user_id',

View File

@@ -1,6 +1,6 @@
"""Unit tests for DomainBlocker class."""
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import MagicMock
import pytest
from server.auth.domain_blocker import DomainBlocker
@@ -9,9 +9,7 @@ from server.auth.domain_blocker import DomainBlocker
@pytest.fixture
def mock_store():
"""Create a mock BlockedEmailDomainStore for testing."""
store = MagicMock()
store.is_domain_blocked = AsyncMock()
return store
return MagicMock()
@pytest.fixture
@@ -59,120 +57,109 @@ def test_extract_domain_invalid_emails(domain_blocker, email, expected):
assert result == expected
@pytest.mark.asyncio
async def test_is_domain_blocked_with_none_email(domain_blocker, mock_store):
def test_is_domain_blocked_with_none_email(domain_blocker, mock_store):
"""Test that is_domain_blocked returns False when email is None."""
# Arrange
mock_store.is_domain_blocked.return_value = True
# Act
result = await domain_blocker.is_domain_blocked(None)
result = domain_blocker.is_domain_blocked(None)
# Assert
assert result is False
mock_store.is_domain_blocked.assert_not_called()
@pytest.mark.asyncio
async def test_is_domain_blocked_with_empty_email(domain_blocker, mock_store):
def test_is_domain_blocked_with_empty_email(domain_blocker, mock_store):
"""Test that is_domain_blocked returns False when email is empty."""
# Arrange
mock_store.is_domain_blocked.return_value = True
# Act
result = await domain_blocker.is_domain_blocked('')
result = domain_blocker.is_domain_blocked('')
# Assert
assert result is False
mock_store.is_domain_blocked.assert_not_called()
@pytest.mark.asyncio
async def test_is_domain_blocked_with_invalid_email(domain_blocker, mock_store):
def test_is_domain_blocked_with_invalid_email(domain_blocker, mock_store):
"""Test that is_domain_blocked returns False when email format is invalid."""
# Arrange
mock_store.is_domain_blocked.return_value = True
# Act
result = await domain_blocker.is_domain_blocked('invalid-email')
result = domain_blocker.is_domain_blocked('invalid-email')
# Assert
assert result is False
mock_store.is_domain_blocked.assert_not_called()
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_not_blocked(domain_blocker, mock_store):
def test_is_domain_blocked_domain_not_blocked(domain_blocker, mock_store):
"""Test that is_domain_blocked returns False when domain is not blocked."""
# Arrange
mock_store.is_domain_blocked.return_value = False
# Act
result = await domain_blocker.is_domain_blocked('user@example.com')
result = domain_blocker.is_domain_blocked('user@example.com')
# Assert
assert result is False
mock_store.is_domain_blocked.assert_called_once_with('example.com')
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_blocked(domain_blocker, mock_store):
def test_is_domain_blocked_domain_blocked(domain_blocker, mock_store):
"""Test that is_domain_blocked returns True when domain is blocked."""
# Arrange
mock_store.is_domain_blocked.return_value = True
# Act
result = await domain_blocker.is_domain_blocked('user@colsch.us')
result = domain_blocker.is_domain_blocked('user@colsch.us')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
@pytest.mark.asyncio
async def test_is_domain_blocked_case_insensitive(domain_blocker, mock_store):
def test_is_domain_blocked_case_insensitive(domain_blocker, mock_store):
"""Test that is_domain_blocked performs case-insensitive domain extraction."""
# Arrange
mock_store.is_domain_blocked.return_value = True
# Act
result = await domain_blocker.is_domain_blocked('user@COLSCH.US')
result = domain_blocker.is_domain_blocked('user@COLSCH.US')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
@pytest.mark.asyncio
async def test_is_domain_blocked_with_whitespace(domain_blocker, mock_store):
def test_is_domain_blocked_with_whitespace(domain_blocker, mock_store):
"""Test that is_domain_blocked handles emails with whitespace correctly."""
# Arrange
mock_store.is_domain_blocked.return_value = True
# Act
result = await domain_blocker.is_domain_blocked(' user@colsch.us ')
result = domain_blocker.is_domain_blocked(' user@colsch.us ')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
@pytest.mark.asyncio
async def test_is_domain_blocked_multiple_blocked_domains(domain_blocker, mock_store):
def test_is_domain_blocked_multiple_blocked_domains(domain_blocker, mock_store):
"""Test that is_domain_blocked correctly checks multiple domains."""
# Arrange
mock_store.is_domain_blocked = AsyncMock(
side_effect=lambda domain: domain
in [
'other-domain.com',
'blocked.org',
]
)
mock_store.is_domain_blocked.side_effect = lambda domain: domain in [
'other-domain.com',
'blocked.org',
]
# Act
result1 = await domain_blocker.is_domain_blocked('user@other-domain.com')
result2 = await domain_blocker.is_domain_blocked('user@blocked.org')
result3 = await domain_blocker.is_domain_blocked('user@allowed.com')
result1 = domain_blocker.is_domain_blocked('user@other-domain.com')
result2 = domain_blocker.is_domain_blocked('user@blocked.org')
result3 = domain_blocker.is_domain_blocked('user@allowed.com')
# Assert
assert result1 is True
@@ -181,8 +168,7 @@ async def test_is_domain_blocked_multiple_blocked_domains(domain_blocker, mock_s
assert mock_store.is_domain_blocked.call_count == 3
@pytest.mark.asyncio
async def test_is_domain_blocked_tld_pattern_blocks_matching_domain(
def test_is_domain_blocked_tld_pattern_blocks_matching_domain(
domain_blocker, mock_store
):
"""Test that TLD pattern blocks domains ending with that TLD."""
@@ -190,15 +176,14 @@ async def test_is_domain_blocked_tld_pattern_blocks_matching_domain(
mock_store.is_domain_blocked.return_value = True
# Act
result = await domain_blocker.is_domain_blocked('user@company.us')
result = domain_blocker.is_domain_blocked('user@company.us')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('company.us')
@pytest.mark.asyncio
async def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld(
def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld(
domain_blocker, mock_store
):
"""Test that TLD pattern blocks subdomains with that TLD."""
@@ -206,15 +191,14 @@ async def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld(
mock_store.is_domain_blocked.return_value = True
# Act
result = await domain_blocker.is_domain_blocked('user@subdomain.company.us')
result = domain_blocker.is_domain_blocked('user@subdomain.company.us')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('subdomain.company.us')
@pytest.mark.asyncio
async def test_is_domain_blocked_tld_pattern_does_not_block_different_tld(
def test_is_domain_blocked_tld_pattern_does_not_block_different_tld(
domain_blocker, mock_store
):
"""Test that TLD pattern does not block domains with different TLD."""
@@ -222,41 +206,35 @@ async def test_is_domain_blocked_tld_pattern_does_not_block_different_tld(
mock_store.is_domain_blocked.return_value = False
# Act
result = await domain_blocker.is_domain_blocked('user@company.com')
result = domain_blocker.is_domain_blocked('user@company.com')
# Assert
assert result is False
mock_store.is_domain_blocked.assert_called_once_with('company.com')
@pytest.mark.asyncio
async def test_is_domain_blocked_tld_pattern_case_insensitive(
domain_blocker, mock_store
):
def test_is_domain_blocked_tld_pattern_case_insensitive(domain_blocker, mock_store):
"""Test that TLD pattern matching is case-insensitive."""
# Arrange
mock_store.is_domain_blocked.return_value = True
# Act
result = await domain_blocker.is_domain_blocked('user@COMPANY.US')
result = domain_blocker.is_domain_blocked('user@COMPANY.US')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('company.us')
@pytest.mark.asyncio
async def test_is_domain_blocked_tld_pattern_with_multi_level_tld(
domain_blocker, mock_store
):
def test_is_domain_blocked_tld_pattern_with_multi_level_tld(domain_blocker, mock_store):
"""Test that TLD pattern works with multi-level TLDs like .co.uk."""
# Arrange
mock_store.is_domain_blocked.side_effect = lambda domain: domain.endswith('.co.uk')
# Act
result_match = await domain_blocker.is_domain_blocked('user@example.co.uk')
result_subdomain = await domain_blocker.is_domain_blocked('user@api.example.co.uk')
result_no_match = await domain_blocker.is_domain_blocked('user@example.uk')
result_match = domain_blocker.is_domain_blocked('user@example.co.uk')
result_subdomain = domain_blocker.is_domain_blocked('user@api.example.co.uk')
result_no_match = domain_blocker.is_domain_blocked('user@example.uk')
# Assert
assert result_match is True
@@ -264,8 +242,7 @@ async def test_is_domain_blocked_tld_pattern_with_multi_level_tld(
assert result_no_match is False
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_pattern_blocks_exact_match(
def test_is_domain_blocked_domain_pattern_blocks_exact_match(
domain_blocker, mock_store
):
"""Test that domain pattern blocks exact domain match."""
@@ -273,31 +250,27 @@ async def test_is_domain_blocked_domain_pattern_blocks_exact_match(
mock_store.is_domain_blocked.return_value = True
# Act
result = await domain_blocker.is_domain_blocked('user@example.com')
result = domain_blocker.is_domain_blocked('user@example.com')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('example.com')
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_pattern_blocks_subdomain(
domain_blocker, mock_store
):
def test_is_domain_blocked_domain_pattern_blocks_subdomain(domain_blocker, mock_store):
"""Test that domain pattern blocks subdomains of that domain."""
# Arrange
mock_store.is_domain_blocked.return_value = True
# Act
result = await domain_blocker.is_domain_blocked('user@subdomain.example.com')
result = domain_blocker.is_domain_blocked('user@subdomain.example.com')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('subdomain.example.com')
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_pattern_blocks_multi_level_subdomain(
def test_is_domain_blocked_domain_pattern_blocks_multi_level_subdomain(
domain_blocker, mock_store
):
"""Test that domain pattern blocks multi-level subdomains."""
@@ -305,15 +278,14 @@ async def test_is_domain_blocked_domain_pattern_blocks_multi_level_subdomain(
mock_store.is_domain_blocked.return_value = True
# Act
result = await domain_blocker.is_domain_blocked('user@api.v2.example.com')
result = domain_blocker.is_domain_blocked('user@api.v2.example.com')
# Assert
assert result is True
mock_store.is_domain_blocked.assert_called_once_with('api.v2.example.com')
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_pattern_does_not_block_similar_domain(
def test_is_domain_blocked_domain_pattern_does_not_block_similar_domain(
domain_blocker, mock_store
):
"""Test that domain pattern does not block domains that contain but don't match the pattern."""
@@ -321,15 +293,14 @@ async def test_is_domain_blocked_domain_pattern_does_not_block_similar_domain(
mock_store.is_domain_blocked.return_value = False
# Act
result = await domain_blocker.is_domain_blocked('user@notexample.com')
result = domain_blocker.is_domain_blocked('user@notexample.com')
# Assert
assert result is False
mock_store.is_domain_blocked.assert_called_once_with('notexample.com')
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_pattern_does_not_block_different_tld(
def test_is_domain_blocked_domain_pattern_does_not_block_different_tld(
domain_blocker, mock_store
):
"""Test that domain pattern does not block same domain with different TLD."""
@@ -337,15 +308,14 @@ async def test_is_domain_blocked_domain_pattern_does_not_block_different_tld(
mock_store.is_domain_blocked.return_value = False
# Act
result = await domain_blocker.is_domain_blocked('user@example.org')
result = domain_blocker.is_domain_blocked('user@example.org')
# Assert
assert result is False
mock_store.is_domain_blocked.assert_called_once_with('example.org')
@pytest.mark.asyncio
async def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(
def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(
domain_blocker, mock_store
):
"""Test that blocking a subdomain also blocks its nested subdomains."""
@@ -355,9 +325,9 @@ async def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(
)
# Act
result_exact = await domain_blocker.is_domain_blocked('user@api.example.com')
result_nested = await domain_blocker.is_domain_blocked('user@v1.api.example.com')
result_parent = await domain_blocker.is_domain_blocked('user@example.com')
result_exact = domain_blocker.is_domain_blocked('user@api.example.com')
result_nested = domain_blocker.is_domain_blocked('user@v1.api.example.com')
result_parent = domain_blocker.is_domain_blocked('user@example.com')
# Assert
assert result_exact is True
@@ -365,15 +335,14 @@ async def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(
assert result_parent is False
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_with_hyphens(domain_blocker, mock_store):
def test_is_domain_blocked_domain_with_hyphens(domain_blocker, mock_store):
"""Test that domain patterns work with hyphenated domains."""
# Arrange
mock_store.is_domain_blocked.return_value = True
# Act
result_exact = await domain_blocker.is_domain_blocked('user@my-company.com')
result_subdomain = await domain_blocker.is_domain_blocked('user@api.my-company.com')
result_exact = domain_blocker.is_domain_blocked('user@my-company.com')
result_subdomain = domain_blocker.is_domain_blocked('user@api.my-company.com')
# Assert
assert result_exact is True
@@ -381,15 +350,14 @@ async def test_is_domain_blocked_domain_with_hyphens(domain_blocker, mock_store)
assert mock_store.is_domain_blocked.call_count == 2
@pytest.mark.asyncio
async def test_is_domain_blocked_domain_with_numbers(domain_blocker, mock_store):
def test_is_domain_blocked_domain_with_numbers(domain_blocker, mock_store):
"""Test that domain patterns work with numeric domains."""
# Arrange
mock_store.is_domain_blocked.return_value = True
# Act
result_exact = await domain_blocker.is_domain_blocked('user@test123.com')
result_subdomain = await domain_blocker.is_domain_blocked('user@api.test123.com')
result_exact = domain_blocker.is_domain_blocked('user@test123.com')
result_subdomain = domain_blocker.is_domain_blocked('user@api.test123.com')
# Assert
assert result_exact is True
@@ -397,14 +365,13 @@ async def test_is_domain_blocked_domain_with_numbers(domain_blocker, mock_store)
assert mock_store.is_domain_blocked.call_count == 2
@pytest.mark.asyncio
async def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker, mock_store):
def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker, mock_store):
"""Test that blocking works with very long subdomain chains."""
# Arrange
mock_store.is_domain_blocked.return_value = True
# Act
result = await domain_blocker.is_domain_blocked(
result = domain_blocker.is_domain_blocked(
'user@level4.level3.level2.level1.example.com'
)
@@ -415,14 +382,13 @@ async def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker, mock_
)
@pytest.mark.asyncio
async def test_is_domain_blocked_handles_store_exception(domain_blocker, mock_store):
def test_is_domain_blocked_handles_store_exception(domain_blocker, mock_store):
"""Test that is_domain_blocked returns False when store raises an exception."""
# Arrange
mock_store.is_domain_blocked.side_effect = Exception('Database connection error')
# Act
result = await domain_blocker.is_domain_blocked('user@example.com')
result = domain_blocker.is_domain_blocked('user@example.com')
# Assert
assert result is False

View File

@@ -1,6 +1,5 @@
import sys
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import MagicMock, patch
import pytest
from fastapi import HTTPException
@@ -28,7 +27,6 @@ async def test_submit_feedback():
"""Test submitting feedback for a conversation."""
# Create a mock database session
mock_session = MagicMock()
mock_session.commit = AsyncMock()
# Test data
feedback_data = FeedbackRequest(
@@ -39,13 +37,19 @@ async def test_submit_feedback():
metadata={'browser': 'Chrome', 'os': 'Windows'},
)
# Create async context manager for a_session_maker
@asynccontextmanager
async def mock_a_session_maker():
yield mock_session
# Mock session_maker and call_sync_from_async
with patch('server.routes.feedback.session_maker') as mock_session_maker, patch(
'server.routes.feedback.call_sync_from_async'
) as mock_call_sync:
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
# Mock call_sync_from_async to execute the function
def mock_call_sync_side_effect(func):
return func()
mock_call_sync.side_effect = mock_call_sync_side_effect
# Mock a_session_maker
with patch('server.routes.feedback.a_session_maker', mock_a_session_maker):
# Call the function
result = await submit_conversation_feedback(feedback_data)
@@ -74,7 +78,6 @@ async def test_invalid_rating():
"""Test submitting feedback with an invalid rating."""
# Create a mock database session
mock_session = MagicMock()
mock_session.commit = AsyncMock()
# Since Pydantic validation happens before our function is called,
# we need to patch the validation to test our function's validation
@@ -92,13 +95,14 @@ async def test_invalid_rating():
# Mock the validation to return our object
mock_validate.return_value = feedback_data
# Create async context manager for a_session_maker
@asynccontextmanager
async def mock_a_session_maker():
yield mock_session
# Mock session_maker and call_sync_from_async
with patch('server.routes.feedback.session_maker') as mock_session_maker, patch(
'server.routes.feedback.call_sync_from_async'
) as mock_call_sync:
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
mock_call_sync.return_value = None
# Mock a_session_maker
with patch('server.routes.feedback.a_session_maker', mock_a_session_maker):
# Call the function and expect an exception
with pytest.raises(HTTPException) as excinfo:
await submit_conversation_feedback(feedback_data)

View File

@@ -2,7 +2,7 @@
Tests for the GitlabCallbackProcessor.
"""
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from integrations.gitlab.gitlab_view import GitlabIssueComment
@@ -111,15 +111,20 @@ class TestGitlabCallbackProcessor:
@patch(
'server.conversation_callback_processor.gitlab_callback_processor.conversation_manager'
)
@patch(
'server.conversation_callback_processor.gitlab_callback_processor.session_maker'
)
async def test_call_with_send_summary_instruction(
self,
mock_session_maker,
mock_conversation_manager,
mock_get_summary_instruction,
async_session_maker,
gitlab_callback_processor,
):
"""Test the __call__ method when send_summary_instruction is True."""
# Setup mocks
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_conversation_manager.send_event_to_conversation = AsyncMock()
mock_get_summary_instruction.return_value = (
"I'm a man of few words. Any questions?"
@@ -137,17 +142,15 @@ class TestGitlabCallbackProcessor:
)
# Call the processor
with patch(
'server.conversation_callback_processor.gitlab_callback_processor.a_session_maker',
async_session_maker,
):
await gitlab_callback_processor(callback, observation)
await gitlab_callback_processor(callback, observation)
# Verify that send_event_to_conversation was called
mock_conversation_manager.send_event_to_conversation.assert_called_once()
# Verify that the processor state was updated
assert gitlab_callback_processor.send_summary_instruction is False
mock_session.merge.assert_called_once_with(callback)
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
@patch(
@@ -159,16 +162,21 @@ class TestGitlabCallbackProcessor:
@patch(
'server.conversation_callback_processor.gitlab_callback_processor.asyncio.create_task'
)
@patch(
'server.conversation_callback_processor.gitlab_callback_processor.session_maker'
)
async def test_call_with_extract_summary(
self,
mock_session_maker,
mock_create_task,
mock_extract_summary,
mock_conversation_manager,
async_session_maker,
gitlab_callback_processor,
):
"""Test the __call__ method when send_summary_instruction is False."""
# Setup mocks
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_extract_summary.return_value = 'Test summary'
# Ensure we don't leak an un-awaited coroutine when create_task is mocked
mock_create_task.side_effect = lambda coro: (coro.close(), None)[1]
@@ -188,22 +196,20 @@ class TestGitlabCallbackProcessor:
)
# Call the processor
with patch(
'server.conversation_callback_processor.gitlab_callback_processor.a_session_maker',
async_session_maker,
):
await gitlab_callback_processor(callback, observation)
await gitlab_callback_processor(callback, observation)
# Verify that extract_summary_from_conversation_manager was called
mock_extract_summary.assert_called_once_with(
mock_conversation_manager, 'conv123'
)
# Verify that create_task was called at least once to send the message
assert mock_create_task.call_count >= 1
# Verify that create_task was called to send the message
mock_create_task.assert_called_once()
# Verify that the callback status was updated
assert callback.status == CallbackStatus.COMPLETED
mock_session.merge.assert_called_once_with(callback)
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_call_with_non_terminal_state(self, gitlab_callback_processor):

View File

@@ -1,54 +1,56 @@
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import select
from server.auth.token_manager import TokenManager
from storage.offline_token_store import OfflineTokenStore
from storage.stored_offline_token import StoredOfflineToken
from openhands.core.config.openhands_config import OpenHandsConfig
@pytest.fixture
def mock_config():
return None # Not used in tests
return MagicMock(spec=OpenHandsConfig)
@pytest.fixture
def token_store(session_maker, mock_config):
return OfflineTokenStore('test_user_id', session_maker, mock_config)
@pytest.fixture
def token_manager():
with patch('server.config.get_config') as mock_get_config:
mock_config = mock_get_config.return_value
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
return TokenManager(external=False)
@pytest.mark.asyncio
async def test_store_token_new_record(async_session_maker, mock_config):
# Setup - inject the test session maker into the store module
import storage.offline_token_store as store_module
store_module.a_session_maker = async_session_maker
token_store = OfflineTokenStore('test_user_id', mock_config)
async def test_store_token_new_record(token_store, session_maker):
# Setup
test_token = 'test_offline_token'
# Execute
await token_store.store_token(test_token)
# Verify - use a new session to query
async with async_session_maker() as session:
result = await session.execute(
select(StoredOfflineToken).where(
StoredOfflineToken.user_id == 'test_user_id'
)
)
record = result.scalar_one_or_none()
assert record is not None
assert record.user_id == 'test_user_id'
assert record.offline_token == test_token
# Verify
with session_maker() as session:
query = session.query(StoredOfflineToken)
assert query.count() == 1
added_record = query.first()
assert added_record.user_id == 'test_user_id'
assert added_record.offline_token == test_token
@pytest.mark.asyncio
async def test_store_token_existing_record(async_session_maker, mock_config):
# Setup - inject the test session maker into the store module
import storage.offline_token_store as store_module
store_module.a_session_maker = async_session_maker
token_store = OfflineTokenStore('test_user_id', mock_config)
async with async_session_maker() as session:
async def test_store_token_existing_record(token_store, session_maker):
# Setup
with session_maker() as session:
session.add(
StoredOfflineToken(user_id='test_user_id', offline_token='old_token')
)
await session.commit()
session.commit()
test_token = 'new_offline_token'
@@ -56,35 +58,24 @@ async def test_store_token_existing_record(async_session_maker, mock_config):
await token_store.store_token(test_token)
# Verify
async with async_session_maker() as session:
from sqlalchemy import select
result = await session.execute(
select(StoredOfflineToken).where(
StoredOfflineToken.user_id == 'test_user_id'
)
)
record = result.scalar_one_or_none()
assert record is not None
assert record.offline_token == test_token
with session_maker() as session:
query = session.query(StoredOfflineToken)
assert query.count() == 1
added_record = query.first()
assert added_record.user_id == 'test_user_id'
assert added_record.offline_token == test_token
@pytest.mark.asyncio
async def test_load_token_existing(async_session_maker, mock_config):
# Setup - inject the test session maker into the store module
import storage.offline_token_store as store_module
store_module.a_session_maker = async_session_maker
token_store = OfflineTokenStore('test_user_id', mock_config)
async with async_session_maker() as session:
async def test_load_token_existing(token_store, session_maker):
# Setup
with session_maker() as session:
session.add(
StoredOfflineToken(
user_id='test_user_id', offline_token='test_offline_token'
)
)
await session.commit()
session.commit()
# Execute
result = await token_store.load_token()
@@ -94,14 +85,7 @@ async def test_load_token_existing(async_session_maker, mock_config):
@pytest.mark.asyncio
async def test_load_token_not_found(async_session_maker, mock_config):
# Setup - inject the test session maker into the store module
import storage.offline_token_store as store_module
store_module.a_session_maker = async_session_maker
token_store = OfflineTokenStore('nonexistent_user', mock_config)
async def test_load_token_not_found(token_store):
# Execute
result = await token_store.load_token()
@@ -120,3 +104,10 @@ async def test_get_instance(mock_config):
# Verify
assert isinstance(result, OfflineTokenStore)
assert result.user_id == test_user_id
assert result.config == mock_config
def test_load_store_org_token(token_manager, session_maker):
with patch('server.auth.token_manager.session_maker', session_maker):
token_manager.store_org_token('some-org-id', 'some-token')
assert token_manager.load_org_token('some-org-id') == 'some-token'

View File

@@ -4,12 +4,17 @@ from unittest.mock import patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.base import Base
from storage.org import Org
from storage.org_member import OrgMember
from storage.org_member_store import OrgMemberStore
from storage.role import Role
from storage.user import User
# Mock the database module before importing OrgMemberStore
with patch('storage.database.engine', create=True), patch(
'storage.database.a_engine', create=True
):
from storage.base import Base
from storage.org import Org
from storage.org_member import OrgMember
from storage.org_member_store import OrgMemberStore
from storage.role import Role
from storage.user import User
@pytest.fixture

View File

@@ -9,18 +9,23 @@ import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from server.routes.org_models import (
LiteLLMIntegrationError,
OrgAuthorizationError,
OrgDatabaseError,
OrgNameExistsError,
OrgNotFoundError,
)
from storage.org import Org
from storage.org_member import OrgMember
from storage.org_service import OrgService
from storage.role import Role
from storage.user import User
# Mock the database module before importing OrgService
with patch('storage.database.engine', create=True), patch(
'storage.database.a_engine', create=True
):
from server.routes.org_models import (
LiteLLMIntegrationError,
OrgAuthorizationError,
OrgDatabaseError,
OrgNameExistsError,
OrgNotFoundError,
)
from storage.org import Org
from storage.org_member import OrgMember
from storage.org_service import OrgService
from storage.role import Role
from storage.user import User
@pytest.fixture

View File

@@ -5,12 +5,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from sqlalchemy.exc import IntegrityError
from storage.org import Org
from storage.org_invitation import OrgInvitation
from storage.org_member import OrgMember
from storage.org_store import OrgStore
from storage.role import Role
from storage.user import User
# Mock the database module before importing OrgStore
with patch('storage.database.engine', create=True), patch(
'storage.database.a_engine', create=True
):
from storage.org import Org
from storage.org_invitation import OrgInvitation
from storage.org_member import OrgMember
from storage.org_store import OrgStore
from storage.role import Role
from storage.user import User
from openhands.storage.data_models.settings import Settings

View File

@@ -1,8 +1,13 @@
from unittest.mock import MagicMock, patch
import pytest
from integrations.github.github_view import get_user_proactive_conversation_setting
from storage.org import Org
# Mock the database module before importing
with patch('storage.database.engine', create=True), patch(
'storage.database.a_engine', create=True
):
from integrations.github.github_view import get_user_proactive_conversation_setting
from storage.org import Org
pytestmark = pytest.mark.asyncio

View File

@@ -1,147 +0,0 @@
from unittest.mock import patch
import pytest
from sqlalchemy import select
from storage.repository_store import RepositoryStore
from storage.stored_repository import StoredRepository
@pytest.fixture
def repository_store():
return RepositoryStore(config=None)
@pytest.mark.asyncio
async def test_store_projects_empty_list(repository_store, async_session_maker):
"""Test storing empty list of repositories."""
with patch(
'storage.repository_store.RepositoryStore.store_projects'
) as mock_method:
# Should handle empty list gracefully
mock_method.return_value = None
# Test that we handle empty repositories
result = await repository_store.store_projects([])
# The method should return early for empty list
assert result is None
@pytest.mark.asyncio
async def test_store_projects_new_repositories(repository_store, async_session_maker):
"""Test storing new repositories in the database."""
# Setup - create repositories
repo1 = StoredRepository(
repo_name='owner/repo1',
repo_id='github##123',
is_public=False,
)
repo2 = StoredRepository(
repo_name='owner/repo2',
repo_id='github##456',
is_public=True,
)
# Execute - patch a_session_maker to use test's async session maker
with patch('storage.repository_store.a_session_maker', async_session_maker):
await repository_store.store_projects([repo1, repo2])
# Verify the repositories were stored
async with async_session_maker() as session:
result = await session.execute(
select(StoredRepository).filter(
StoredRepository.repo_id.in_(['github##123', 'github##456'])
)
)
repos = result.scalars().all()
assert len(repos) == 2
repo_ids = {r.repo_id for r in repos}
assert 'github##123' in repo_ids
assert 'github##456' in repo_ids
@pytest.mark.asyncio
async def test_store_projects_update_existing(repository_store, async_session_maker):
"""Test updating existing repositories in the database."""
# Setup - create existing repository
existing_repo = StoredRepository(
repo_name='owner/repo1',
repo_id='github##123',
is_public=True,
)
async with async_session_maker() as session:
session.add(existing_repo)
await session.commit()
# Execute - update the repository with new values
updated_repo = StoredRepository(
repo_name='owner/repo1-updated',
repo_id='github##123',
is_public=False, # Changed from True
)
with patch('storage.repository_store.a_session_maker', async_session_maker):
await repository_store.store_projects([updated_repo])
# Verify the repository was updated
async with async_session_maker() as session:
result = await session.execute(
select(StoredRepository).filter(StoredRepository.repo_id == 'github##123')
)
repo = result.scalars().first()
assert repo is not None
assert repo.repo_name == 'owner/repo1-updated'
assert repo.is_public is False
@pytest.mark.asyncio
async def test_store_projects_mixed_new_and_existing(
repository_store, async_session_maker
):
"""Test storing a mix of new and existing repositories."""
# Setup - create one existing repository
existing_repo = StoredRepository(
repo_name='owner/existing-repo',
repo_id='github##123',
is_public=True,
)
async with async_session_maker() as session:
session.add(existing_repo)
await session.commit()
# Execute - store a mix of new and existing
repos_to_store = [
StoredRepository(
repo_name='owner/existing-repo',
repo_id='github##123',
is_public=False, # Will update
),
StoredRepository(
repo_name='owner/new-repo',
repo_id='github##456',
is_public=True,
),
]
with patch('storage.repository_store.a_session_maker', async_session_maker):
await repository_store.store_projects(repos_to_store)
# Verify results
async with async_session_maker() as session:
result = await session.execute(
select(StoredRepository).filter(
StoredRepository.repo_id.in_(['github##123', 'github##456'])
)
)
repos = result.scalars().all()
assert len(repos) == 2
# Check the updated existing repo
existing = next(r for r in repos if r.repo_id == 'github##123')
assert existing.repo_name == 'owner/existing-repo'
assert existing.is_public is False
# Check the new repo
new = next(r for r in repos if r.repo_id == 'github##456')
assert new.repo_name == 'owner/new-repo'
assert new.is_public is True

View File

@@ -1,14 +1,16 @@
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from storage.saas_conversation_store import SaasConversationStore
from storage.user import User
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
# Mock the database module before importing
with patch('storage.database.engine'), patch('storage.database.a_engine'):
from storage.saas_conversation_store import SaasConversationStore
from storage.user import User
@pytest.fixture(autouse=True)
def mock_call_sync_from_async():
@@ -164,53 +166,3 @@ async def test_exists(session_maker):
assert not await store.exists('exists-test')
await store.save_metadata(metadata)
assert await store.exists('exists-test')
class TestGetInstance:
"""Tests for SaasConversationStore.get_instance method.
The get_instance method uses async UserStore.get_user_by_id_async because
callers now use asyncio.run_coroutine_threadsafe() to dispatch to the main
event loop where asyncpg connections work properly.
"""
@pytest.mark.asyncio
async def test_get_instance_uses_async_get_user_by_id(self):
"""Verify get_instance calls the async get_user_by_id_async for proper event loop handling."""
# Arrange
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
mock_user = MagicMock(spec=User)
mock_user.current_org_id = UUID(user_id)
mock_config = MagicMock(spec=OpenHandsConfig)
with patch(
'storage.saas_conversation_store.UserStore.get_user_by_id_async',
AsyncMock(return_value=mock_user),
) as mock_async_get_user, patch(
'storage.saas_conversation_store.session_maker'
):
# Act
store = await SaasConversationStore.get_instance(mock_config, user_id)
# Assert
mock_async_get_user.assert_called_once_with(user_id)
assert store.user_id == user_id
assert store.org_id == mock_user.current_org_id
@pytest.mark.asyncio
async def test_get_instance_handles_none_user(self):
"""Verify get_instance handles case when user is not found."""
# Arrange
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
mock_config = MagicMock(spec=OpenHandsConfig)
with patch(
'storage.saas_conversation_store.UserStore.get_user_by_id_async',
AsyncMock(return_value=None),
), patch('storage.saas_conversation_store.session_maker'):
# Act
store = await SaasConversationStore.get_instance(mock_config, user_id)
# Assert
assert store.user_id == user_id
assert store.org_id is None

View File

@@ -29,16 +29,8 @@ def mock_user():
@pytest.fixture
def secrets_store(async_session_maker, mock_config):
# Inject the test session maker into the store module
import storage.saas_secrets_store as store_module
store_module.a_session_maker = async_session_maker
store = SaasSecretsStore('user-id', mock_config)
# Also add it as an attribute for tests that need direct access
store.a_session_maker = async_session_maker
return store
def secrets_store(session_maker, mock_config):
return SaasSecretsStore('user-id', session_maker, mock_config)
class TestSaasSecretsStore:
@@ -115,15 +107,13 @@ class TestSaasSecretsStore:
await secrets_store.store(user_secrets)
# Verify the data is encrypted in the database
from sqlalchemy import select
async with secrets_store.a_session_maker() as session:
result = await session.execute(
select(StoredCustomSecrets)
with secrets_store.session_maker() as session:
stored = (
session.query(StoredCustomSecrets)
.filter(StoredCustomSecrets.keycloak_user_id == 'user-id')
.filter(StoredCustomSecrets.org_id == mock_user.current_org_id)
.first()
)
stored = result.scalars().first()
# The sensitive data should be encrypted
assert stored.secret_value != 'sensitive_token'

View File

@@ -8,7 +8,7 @@ from openhands.server.settings import Settings
from openhands.storage.data_models.settings import Settings as DataSettings
# Mock the database module before importing
with patch('storage.database.a_session_maker'):
with patch('storage.database.engine'), patch('storage.database.a_engine'):
from server.constants import (
LITE_LLM_API_URL,
)
@@ -26,21 +26,19 @@ def mock_config():
@pytest.fixture
def settings_store(async_session_maker, mock_config):
store = SaasSettingsStore('5594c7b6-f959-4b81-92e9-b09c206f5081', mock_config)
store.a_session_maker = async_session_maker
def settings_store(session_maker, mock_config):
store = SaasSettingsStore(
'5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker, mock_config
)
# Patch the load method to read from UserSettings table directly (for testing)
async def patched_load():
async with store.a_session_maker() as session:
from sqlalchemy import select
result = await session.execute(
select(UserSettings).filter(
UserSettings.keycloak_user_id == store.user_id
)
with store.session_maker() as session:
user_settings = (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == store.user_id)
.first()
)
user_settings = result.scalars().first()
if not user_settings:
# Return default settings
return Settings(
@@ -76,31 +74,29 @@ def settings_store(async_session_maker, mock_config):
if 'secrets_store' in item_dict:
del item_dict['secrets_store']
# Encrypt the data before storing
store._encrypt_kwargs(item_dict)
# Continue with the original implementation
from sqlalchemy import select
async with store.a_session_maker() as session:
result = await session.execute(
select(UserSettings).filter(
with store.session_maker() as session:
existing = None
if item_dict:
store._encrypt_kwargs(item_dict)
query = session.query(UserSettings).filter(
UserSettings.keycloak_user_id == store.user_id
)
)
existing = result.scalars().first()
# First check if we have an existing entry in the new table
existing = query.first()
if existing:
# Update existing entry
for key, value in item_dict.items():
if key in existing.__class__.__table__.columns:
setattr(existing, key, value)
await session.merge(existing)
session.merge(existing)
else:
item_dict['keycloak_user_id'] = store.user_id
settings = UserSettings(**item_dict)
session.add(settings)
await session.commit()
session.commit()
# Replace the methods with our patched versions
store.store = patched_store
@@ -129,26 +125,25 @@ async def test_store_and_load_keycloak_user(settings_store):
assert loaded_settings.agent == 'smith'
# Verify it was stored in user_settings table with keycloak_user_id
from sqlalchemy import select
async with settings_store.a_session_maker() as session:
result = await session.execute(
select(UserSettings).filter(
with settings_store.session_maker() as session:
stored = (
session.query(UserSettings)
.filter(
UserSettings.keycloak_user_id == '550e8400-e29b-41d4-a716-446655440000'
)
.first()
)
stored = result.scalars().first()
assert stored is not None
assert stored.agent == 'smith'
@pytest.mark.asyncio
async def test_load_returns_default_when_not_found(settings_store, async_session_maker):
async def test_load_returns_default_when_not_found(settings_store, session_maker):
file_store = MagicMock()
file_store.read.side_effect = FileNotFoundError()
with (
patch('storage.saas_settings_store.a_session_maker', async_session_maker),
patch('storage.saas_settings_store.session_maker', session_maker),
):
loaded_settings = await settings_store.load()
assert loaded_settings is not None
@@ -169,15 +164,14 @@ async def test_encryption(settings_store):
email_verified=True,
)
await settings_store.store(settings)
from sqlalchemy import select
async with settings_store.a_session_maker() as session:
result = await session.execute(
select(UserSettings).filter(
with settings_store.session_maker() as session:
stored = (
session.query(UserSettings)
.filter(
UserSettings.keycloak_user_id == '5594c7b6-f959-4b81-92e9-b09c206f5081'
)
.first()
)
stored = result.scalars().first()
# The stored key should be encrypted
assert stored.llm_api_key != 'secret_key'
# But we should be able to decrypt it when loading
@@ -188,7 +182,7 @@ async def test_encryption(settings_store):
@pytest.mark.asyncio
async def test_ensure_api_key_keeps_valid_key(mock_config):
"""When the existing key is valid, it should be kept unchanged."""
store = SaasSettingsStore('test-user-id-123', mock_config)
store = SaasSettingsStore('test-user-id-123', MagicMock(), mock_config)
existing_key = 'sk-existing-key'
item = DataSettings(
llm_model='openhands/gpt-4', llm_api_key=SecretStr(existing_key)
@@ -211,7 +205,7 @@ async def test_ensure_api_key_generates_new_key_when_verification_fails(
mock_config,
):
"""When verification fails, a new key should be generated."""
store = SaasSettingsStore('test-user-id-123', mock_config)
store = SaasSettingsStore('test-user-id-123', MagicMock(), mock_config)
new_key = 'sk-new-key'
item = DataSettings(
llm_model='openhands/gpt-4', llm_api_key=SecretStr('sk-invalid-key')

Some files were not shown because too many files have changed in this diff Show More