mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
3 Commits
optimize-d
...
fix/git-ap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d4f7f07d5d | ||
|
|
a34dc949ce | ||
|
|
80e4fe1226 |
@@ -193,20 +193,14 @@ class GithubManager(Manager):
|
||||
github_view.installation_id
|
||||
)
|
||||
# Store the installation token
|
||||
await self.token_manager.store_org_token(
|
||||
self.token_manager.store_org_token(
|
||||
github_view.installation_id, installation_token
|
||||
)
|
||||
# Add eyes reaction to acknowledge we've read the request
|
||||
self._add_reaction(github_view, 'eyes', installation_token)
|
||||
await self.start_job(github_view)
|
||||
|
||||
async def send_message(self, message: str, github_view: ResolverViewInterface):
|
||||
"""Send a message to GitHub.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
github_view: The GitHub view object containing issue/PR/comment info
|
||||
"""
|
||||
async def send_message(self, message: Message, github_view: ResolverViewInterface):
|
||||
installation_token = self.token_manager.load_org_token(
|
||||
github_view.installation_id
|
||||
)
|
||||
@@ -214,12 +208,14 @@ class GithubManager(Manager):
|
||||
logger.warning('Missing installation token')
|
||||
return
|
||||
|
||||
outgoing_message = message.message
|
||||
|
||||
if isinstance(github_view, GithubInlinePRComment):
|
||||
with Github(auth=Auth.Token(installation_token)) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
pr = repo.get_pull(github_view.issue_number)
|
||||
pr.create_review_comment_reply(
|
||||
comment_id=github_view.comment_id, body=message
|
||||
comment_id=github_view.comment_id, body=outgoing_message
|
||||
)
|
||||
|
||||
elif (
|
||||
@@ -230,7 +226,7 @@ class GithubManager(Manager):
|
||||
with Github(auth=Auth.Token(installation_token)) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
issue = repo.get_issue(number=github_view.issue_number)
|
||||
issue.create_comment(message)
|
||||
issue.create_comment(outgoing_message)
|
||||
|
||||
else:
|
||||
logger.warning('Unsupported location')
|
||||
@@ -249,7 +245,7 @@ class GithubManager(Manager):
|
||||
)
|
||||
|
||||
try:
|
||||
msg_info: str = ''
|
||||
msg_info = None
|
||||
|
||||
try:
|
||||
user_info = github_view.user_info
|
||||
@@ -365,13 +361,15 @@ class GithubManager(Manager):
|
||||
|
||||
msg_info = get_session_expired_message(user_info.username)
|
||||
|
||||
await self.send_message(msg_info, github_view)
|
||||
msg = self.create_outgoing_message(msg_info)
|
||||
await self.send_message(msg, github_view)
|
||||
|
||||
except Exception:
|
||||
logger.exception('[Github]: Error starting job')
|
||||
await self.send_message(
|
||||
'Uh oh! There was an unexpected error starting the job :(', github_view
|
||||
msg = self.create_outgoing_message(
|
||||
msg='Uh oh! There was an unexpected error starting the job :('
|
||||
)
|
||||
await self.send_message(msg, github_view)
|
||||
|
||||
try:
|
||||
await self.data_collector.save_data(github_view)
|
||||
|
||||
@@ -14,6 +14,7 @@ from integrations.solvability.models.summary import SolvabilitySummary
|
||||
from integrations.utils import ENABLE_SOLVABILITY_ANALYSIS
|
||||
from pydantic import ValidationError
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
@@ -89,6 +90,7 @@ async def summarize_issue_solvability(
|
||||
# Grab the user's information so we can load their LLM configuration
|
||||
store = SaasSettingsStore(
|
||||
user_id=github_view.user_info.keycloak_user_id,
|
||||
session_maker=session_maker,
|
||||
config=get_config(),
|
||||
)
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from jinja2 import Environment
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.org_store import OrgStore
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
@@ -152,7 +153,9 @@ class GithubIssue(ResolverViewInterface):
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
secrets_store = SaasSecretsStore(self.user_info.keycloak_user_id, get_config())
|
||||
secrets_store = SaasSecretsStore(
|
||||
self.user_info.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
user_secrets = await secrets_store.load()
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
|
||||
@@ -121,11 +121,12 @@ class GitlabManager(Manager):
|
||||
# Check if the user has write access to the repository
|
||||
return has_write_access
|
||||
|
||||
async def send_message(self, message: str, gitlab_view: ResolverViewInterface):
|
||||
"""Send a message to GitLab based on the view type.
|
||||
async def send_message(self, message: Message, gitlab_view: ResolverViewInterface):
|
||||
"""
|
||||
Send a message to GitLab based on the view type.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
message: The message to send
|
||||
gitlab_view: The GitLab view object containing issue/PR/comment info
|
||||
"""
|
||||
keycloak_user_id = gitlab_view.user_info.keycloak_user_id
|
||||
@@ -137,6 +138,8 @@ class GitlabManager(Manager):
|
||||
external_auth_id=keycloak_user_id
|
||||
)
|
||||
|
||||
outgoing_message = message.message
|
||||
|
||||
if isinstance(gitlab_view, GitlabInlineMRComment) or isinstance(
|
||||
gitlab_view, GitlabMRComment
|
||||
):
|
||||
@@ -144,7 +147,7 @@ class GitlabManager(Manager):
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
gitlab_view.discussion_id,
|
||||
message,
|
||||
message.message,
|
||||
)
|
||||
|
||||
elif isinstance(gitlab_view, GitlabIssueComment):
|
||||
@@ -152,14 +155,14 @@ class GitlabManager(Manager):
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
gitlab_view.discussion_id,
|
||||
message,
|
||||
outgoing_message,
|
||||
)
|
||||
elif isinstance(gitlab_view, GitlabIssue):
|
||||
await gitlab_service.reply_to_issue(
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
None, # no discussion id, issue is tagged
|
||||
message,
|
||||
outgoing_message,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -259,10 +262,12 @@ class GitlabManager(Manager):
|
||||
msg_info = get_session_expired_message(user_info.username)
|
||||
|
||||
# Send the acknowledgment message
|
||||
await self.send_message(msg_info, gitlab_view)
|
||||
msg = self.create_outgoing_message(msg_info)
|
||||
await self.send_message(msg, gitlab_view)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f'[GitLab] Error starting job: {str(e)}')
|
||||
await self.send_message(
|
||||
'Uh oh! There was an unexpected error starting the job :(', gitlab_view
|
||||
msg = self.create_outgoing_message(
|
||||
msg='Uh oh! There was an unexpected error starting the job :('
|
||||
)
|
||||
await self.send_message(msg, gitlab_view)
|
||||
|
||||
@@ -6,6 +6,7 @@ from integrations.utils import HOST, get_oh_labels, has_exact_mention
|
||||
from jinja2 import Environment
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -77,7 +78,9 @@ class GitlabIssue(ResolverViewInterface):
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
secrets_store = SaasSecretsStore(self.user_info.keycloak_user_id, get_config())
|
||||
secrets_store = SaasSecretsStore(
|
||||
self.user_info.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
user_secrets = await secrets_store.load()
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
@@ -446,5 +449,3 @@ class GitlabFactory:
|
||||
previous_comments=[],
|
||||
is_mr=True,
|
||||
)
|
||||
|
||||
raise ValueError(f'Unhandled GitLab webhook event: {message}')
|
||||
|
||||
@@ -341,25 +341,17 @@ class JiraManager(Manager):
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: str,
|
||||
message: Message,
|
||||
issue_key: str,
|
||||
jira_cloud_id: str,
|
||||
svc_acc_email: str,
|
||||
svc_acc_api_key: str,
|
||||
):
|
||||
"""Send a comment to a Jira issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
issue_key: The Jira issue key (e.g., 'PROJ-123')
|
||||
jira_cloud_id: The Jira Cloud ID
|
||||
svc_acc_email: Service account email for authentication
|
||||
svc_acc_api_key: Service account API key for authentication
|
||||
"""
|
||||
"""Send a comment to a Jira issue."""
|
||||
url = (
|
||||
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
|
||||
)
|
||||
data = {'body': message}
|
||||
data = {'body': message.message}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(
|
||||
url, auth=(svc_acc_email, svc_acc_api_key), json=data
|
||||
@@ -374,7 +366,7 @@ class JiraManager(Manager):
|
||||
view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
msg,
|
||||
self.create_outgoing_message(msg=msg),
|
||||
issue_key=view.payload.issue_key,
|
||||
jira_cloud_id=view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=view.jira_workspace.svc_acc_email,
|
||||
@@ -396,7 +388,7 @@ class JiraManager(Manager):
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
error_msg,
|
||||
self.create_outgoing_message(msg=error_msg),
|
||||
issue_key=payload.issue_key,
|
||||
jira_cloud_id=workspace.jira_cloud_id,
|
||||
svc_acc_email=workspace.svc_acc_email,
|
||||
|
||||
@@ -212,6 +212,8 @@ class JiraPayloadParser:
|
||||
missing.append('issue.id')
|
||||
if not issue_key:
|
||||
missing.append('issue.key')
|
||||
if not user_email:
|
||||
missing.append('user.emailAddress')
|
||||
if not display_name:
|
||||
missing.append('user.displayName')
|
||||
if not account_id:
|
||||
|
||||
@@ -418,7 +418,7 @@ class JiraDcManager(Manager):
|
||||
jira_dc_view.jira_dc_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
msg_info,
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
issue_key=jira_dc_view.job_context.issue_key,
|
||||
base_api_url=jira_dc_view.job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
@@ -456,19 +456,12 @@ class JiraDcManager(Manager):
|
||||
return title, description
|
||||
|
||||
async def send_message(
|
||||
self, message: str, issue_key: str, base_api_url: str, svc_acc_api_key: str
|
||||
self, message: Message, issue_key: str, base_api_url: str, svc_acc_api_key: str
|
||||
):
|
||||
"""Send message/comment to Jira DC issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
issue_key: The Jira issue key (e.g., 'PROJ-123')
|
||||
base_api_url: The base API URL for the Jira DC instance
|
||||
svc_acc_api_key: Service account API key for authentication
|
||||
"""
|
||||
"""Send message/comment to Jira DC issue."""
|
||||
url = f'{base_api_url}/rest/api/2/issue/{issue_key}/comment'
|
||||
headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
|
||||
data = {'body': message}
|
||||
data = {'body': message.message}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
@@ -488,7 +481,7 @@ class JiraDcManager(Manager):
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
error_msg,
|
||||
self.create_outgoing_message(msg=error_msg),
|
||||
issue_key=job_context.issue_key,
|
||||
base_api_url=job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
@@ -509,7 +502,7 @@ class JiraDcManager(Manager):
|
||||
)
|
||||
|
||||
await self.send_message(
|
||||
comment_msg,
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
issue_key=jira_dc_view.job_context.issue_key,
|
||||
base_api_url=jira_dc_view.job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
|
||||
@@ -19,7 +19,7 @@ class JiraDcViewInterface(ABC):
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ class JiraDcNewConversationView(JiraDcViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('jira_dc_instructions.j2')
|
||||
@@ -61,7 +61,7 @@ class JiraDcNewConversationView(JiraDcViewInterface):
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = await self._get_instructions(jinja_env)
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
@@ -113,7 +113,7 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_dc_existing_conversation.j2')
|
||||
@@ -167,7 +167,7 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = await self._get_instructions(jinja_env)
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
|
||||
@@ -408,7 +408,7 @@ class LinearManager(Manager):
|
||||
linear_view.linear_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
msg_info,
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
linear_view.job_context.issue_id,
|
||||
api_key,
|
||||
)
|
||||
@@ -473,14 +473,8 @@ class LinearManager(Manager):
|
||||
|
||||
return title, description
|
||||
|
||||
async def send_message(self, message: str, issue_id: str, api_key: str):
|
||||
"""Send message/comment to Linear issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
issue_id: The Linear issue ID to comment on
|
||||
api_key: The Linear API key for authentication
|
||||
"""
|
||||
async def send_message(self, message: Message, issue_id: str, api_key: str):
|
||||
"""Send message/comment to Linear issue."""
|
||||
query = """
|
||||
mutation CommentCreate($input: CommentCreateInput!) {
|
||||
commentCreate(input: $input) {
|
||||
@@ -491,7 +485,7 @@ class LinearManager(Manager):
|
||||
}
|
||||
}
|
||||
"""
|
||||
variables = {'input': {'issueId': issue_id, 'body': message}}
|
||||
variables = {'input': {'issueId': issue_id, 'body': message.message}}
|
||||
return await self._query_api(query, variables, api_key)
|
||||
|
||||
async def _send_error_comment(
|
||||
@@ -504,7 +498,9 @@ class LinearManager(Manager):
|
||||
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(error_msg, issue_id, api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg), issue_id, api_key
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Linear] Failed to send error comment: {str(e)}')
|
||||
|
||||
@@ -521,7 +517,7 @@ class LinearManager(Manager):
|
||||
)
|
||||
|
||||
await self.send_message(
|
||||
comment_msg,
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
linear_view.job_context.issue_id,
|
||||
api_key,
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ class LinearViewInterface(ABC):
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class LinearNewConversationView(LinearViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('linear_instructions.j2')
|
||||
@@ -58,7 +58,7 @@ class LinearNewConversationView(LinearViewInterface):
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = await self._get_instructions(jinja_env)
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
@@ -110,7 +110,7 @@ class LinearExistingConversationView(LinearViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('linear_existing_conversation.j2')
|
||||
@@ -164,7 +164,7 @@ class LinearExistingConversationView(LinearViewInterface):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = await self._get_instructions(jinja_env)
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from integrations.models import Message, SourceType
|
||||
|
||||
@@ -13,15 +12,14 @@ class Manager(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def send_message(self, message: str, *args: Any, **kwargs: Any):
|
||||
"""Send message to integration from OpenHands server.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string).
|
||||
"""
|
||||
def send_message(self, message: Message):
|
||||
"Send message to integration from Openhands server"
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self):
|
||||
"Kick off a job with openhands agent"
|
||||
raise NotImplementedError
|
||||
|
||||
def create_outgoing_message(self, msg: str | dict, ephemeral: bool = False):
|
||||
return Message(source=SourceType.OPENHANDS, message=msg, ephemeral=ephemeral)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -17,16 +16,8 @@ class SourceType(str, Enum):
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Message model for incoming webhook payloads from integrations.
|
||||
|
||||
Note: This model is intended for INCOMING messages only.
|
||||
For outgoing messages (e.g., sending comments to GitHub/GitLab),
|
||||
pass strings directly to the send_message methods instead of
|
||||
wrapping them in a Message object.
|
||||
"""
|
||||
|
||||
source: SourceType
|
||||
message: dict[str, Any]
|
||||
message: str | dict
|
||||
ephemeral: bool = False
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from integrations.manager import Manager
|
||||
@@ -23,8 +22,7 @@ from server.constants import SLACK_CLIENT_ID
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
from slack_sdk.oauth import AuthorizeUrlGenerator
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.slack_user import SlackUser
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -65,11 +63,12 @@ class SlackManager(Manager):
|
||||
) -> tuple[SlackUser | None, UserAuth | None]:
|
||||
# We get the user and correlate them back to a user in OpenHands - if we can
|
||||
slack_user = None
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(SlackUser).where(SlackUser.slack_user_id == slack_user_id)
|
||||
with session_maker() as session:
|
||||
slack_user = (
|
||||
session.query(SlackUser)
|
||||
.filter(SlackUser.slack_user_id == slack_user_id)
|
||||
.first()
|
||||
)
|
||||
slack_user = result.scalar_one_or_none()
|
||||
|
||||
# slack_view.slack_to_openhands_user = slack_user # attach user auth info to view
|
||||
|
||||
@@ -203,7 +202,9 @@ class SlackManager(Manager):
|
||||
msg = self.login_link.format(link)
|
||||
|
||||
logger.info('slack_not_yet_authenticated')
|
||||
await self.send_message(msg, slack_view, ephemeral=True)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg, ephemeral=True), slack_view
|
||||
)
|
||||
return
|
||||
|
||||
if not await self.is_job_requested(message, slack_view):
|
||||
@@ -211,40 +212,27 @@ class SlackManager(Manager):
|
||||
|
||||
await self.start_job(slack_view)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: str | dict[str, Any],
|
||||
slack_view: SlackViewInterface,
|
||||
ephemeral: bool = False,
|
||||
):
|
||||
"""Send a message to Slack.
|
||||
|
||||
Args:
|
||||
message: The message content. Can be a string (for simple text) or
|
||||
a dict with 'text' and 'blocks' keys (for structured messages).
|
||||
slack_view: The Slack view object containing channel/thread info.
|
||||
ephemeral: If True, send as an ephemeral message visible only to the user.
|
||||
"""
|
||||
async def send_message(self, message: Message, slack_view: SlackViewInterface):
|
||||
client = AsyncWebClient(token=slack_view.bot_access_token)
|
||||
if ephemeral and isinstance(message, str):
|
||||
if message.ephemeral and isinstance(message.message, str):
|
||||
await client.chat_postEphemeral(
|
||||
channel=slack_view.channel_id,
|
||||
markdown_text=message,
|
||||
markdown_text=message.message,
|
||||
user=slack_view.slack_user_id,
|
||||
thread_ts=slack_view.thread_ts,
|
||||
)
|
||||
elif ephemeral and isinstance(message, dict):
|
||||
elif message.ephemeral and isinstance(message.message, dict):
|
||||
await client.chat_postEphemeral(
|
||||
channel=slack_view.channel_id,
|
||||
user=slack_view.slack_user_id,
|
||||
thread_ts=slack_view.thread_ts,
|
||||
text=message['text'],
|
||||
blocks=message['blocks'],
|
||||
text=message.message['text'],
|
||||
blocks=message.message['blocks'],
|
||||
)
|
||||
else:
|
||||
await client.chat_postMessage(
|
||||
channel=slack_view.channel_id,
|
||||
markdown_text=message,
|
||||
markdown_text=message.message,
|
||||
thread_ts=slack_view.message_ts,
|
||||
)
|
||||
|
||||
@@ -291,7 +279,10 @@ class SlackManager(Manager):
|
||||
repos, slack_view.message_ts, slack_view.thread_ts
|
||||
),
|
||||
}
|
||||
await self.send_message(repo_selection_msg, slack_view, ephemeral=True)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(repo_selection_msg, ephemeral=True),
|
||||
slack_view,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
@@ -377,10 +368,9 @@ class SlackManager(Manager):
|
||||
except StartingConvoException as e:
|
||||
msg_info = str(e)
|
||||
|
||||
await self.send_message(msg_info, slack_view)
|
||||
await self.send_message(self.create_outgoing_message(msg_info), slack_view)
|
||||
|
||||
except Exception:
|
||||
logger.exception('[Slack]: Error starting job')
|
||||
await self.send_message(
|
||||
'Uh oh! There was an unexpected error starting the job :(', slack_view
|
||||
)
|
||||
msg = 'Uh oh! There was an unexpected error starting the job :('
|
||||
await self.send_message(self.create_outgoing_message(msg), slack_view)
|
||||
|
||||
@@ -24,7 +24,7 @@ class SlackViewInterface(SummaryExtractionTracker, ABC):
|
||||
v1_enabled: bool
|
||||
|
||||
@abstractmethod
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ class SlackUnkownUserView(SlackViewInterface):
|
||||
team_id: str
|
||||
v1_enabled: bool
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment):
|
||||
@@ -118,7 +118,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
return block['user_id']
|
||||
return ''
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
|
||||
@@ -242,9 +242,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
self, jinja: Environment, provider_tokens, user_secrets
|
||||
) -> None:
|
||||
"""Create conversation using the legacy V0 system."""
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja
|
||||
)
|
||||
user_instructions, conversation_instructions = self._get_instructions(jinja)
|
||||
|
||||
# Determine git provider from repository
|
||||
git_provider = None
|
||||
@@ -275,9 +273,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
|
||||
async def _create_v1_conversation(self, jinja: Environment) -> None:
|
||||
"""Create conversation using the new V1 app conversation system."""
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja
|
||||
)
|
||||
user_instructions, conversation_instructions = self._get_instructions(jinja)
|
||||
|
||||
# Create the initial message request
|
||||
initial_message = SendMessageRequest(
|
||||
@@ -350,7 +346,7 @@ class SlackNewConversationFromRepoFormView(SlackNewConversationView):
|
||||
class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
slack_conversation: SlackConversation
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
client = WebClient(token=self.bot_access_token)
|
||||
result = client.conversations_replies(
|
||||
channel=self.channel_id,
|
||||
@@ -405,7 +401,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
instructions, _ = await self._get_instructions(jinja)
|
||||
instructions, _ = self._get_instructions(jinja)
|
||||
user_msg = MessageAction(content=instructions)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_msg)
|
||||
@@ -473,7 +469,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
agent_server_url = get_agent_server_url_from_sandbox(running_sandbox)
|
||||
|
||||
# 4. Prepare the message content
|
||||
user_msg, _ = await self._get_instructions(jinja)
|
||||
user_msg, _ = self._get_instructions(jinja)
|
||||
|
||||
# 5. Create the message request
|
||||
send_message_request = SendMessageRequest(
|
||||
|
||||
@@ -42,11 +42,11 @@ async def store_repositories_in_db(repos: list[Repository], user_id: str) -> Non
|
||||
try:
|
||||
# Store repositories in the repos table
|
||||
repo_store = RepositoryStore.get_instance(config)
|
||||
await repo_store.store_projects(stored_repos)
|
||||
repo_store.store_projects(stored_repos)
|
||||
|
||||
# Store user-repository mappings in the user-repos table
|
||||
user_repo_store = UserRepositoryMapStore.get_instance(config)
|
||||
await user_repo_store.store_user_repo_mappings(user_repos)
|
||||
user_repo_store.store_user_repo_mappings(user_repos)
|
||||
|
||||
logger.info(f'Saved repos for user {user_id}')
|
||||
except Exception:
|
||||
|
||||
@@ -3,8 +3,8 @@ from uuid import UUID
|
||||
import stripe
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.database import session_maker
|
||||
from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
@@ -15,10 +15,12 @@ stripe.api_key = STRIPE_API_KEY
|
||||
|
||||
|
||||
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
async with a_session_maker() as session:
|
||||
stmt = select(StripeCustomer).where(StripeCustomer.org_id == org_id)
|
||||
result = await session.execute(stmt)
|
||||
stripe_customer = result.scalar_one_or_none()
|
||||
with session_maker() as session:
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.org_id == org_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer:
|
||||
return stripe_customer.stripe_customer_id
|
||||
|
||||
@@ -72,7 +74,7 @@ async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
)
|
||||
|
||||
# Save the stripe customer in the local db
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id=user_id,
|
||||
@@ -80,7 +82,7 @@ async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
stripe_customer_id=customer.id,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'created_customer',
|
||||
@@ -106,27 +108,26 @@ async def has_payment_method_by_user_id(user_id: str) -> bool:
|
||||
return bool(payment_methods.data)
|
||||
|
||||
|
||||
async def migrate_customer(user_id: str, org: Org):
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StripeCustomer).where(StripeCustomer.keycloak_user_id == user_id)
|
||||
)
|
||||
stripe_customer = result.scalar_one_or_none()
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
async def migrate_customer(session: Session, user_id: str, org: Org):
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -38,7 +38,7 @@ class ResolverViewInterface(SummaryExtractionTracker):
|
||||
is_public_repo: bool
|
||||
raw_payload: dict
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"Instructions passed when conversation is first initialized"
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
17
enterprise/poetry.lock
generated
17
enterprise/poetry.lock
generated
@@ -1591,9 +1591,6 @@ files = [
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7"},
|
||||
{file = "cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d"},
|
||||
]
|
||||
|
||||
@@ -6171,23 +6168,23 @@ opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
|
||||
pathspec = ">=0.12.1"
|
||||
pexpect = "*"
|
||||
pg8000 = ">=1.31.5"
|
||||
pillow = ">=12.1.1"
|
||||
pillow = ">=11.3"
|
||||
playwright = ">=1.55"
|
||||
poetry = ">=2.1.2"
|
||||
prompt-toolkit = ">=3.0.50"
|
||||
protobuf = ">=5.29.6,<6"
|
||||
protobuf = ">=5,<6"
|
||||
psutil = "*"
|
||||
pybase62 = ">=1"
|
||||
pygithub = ">=2.5"
|
||||
pyjwt = ">=2.9"
|
||||
pylatexenc = "*"
|
||||
pypdf = ">=6.7.2"
|
||||
pypdf = ">=6"
|
||||
python-docx = "*"
|
||||
python-dotenv = "*"
|
||||
python-frontmatter = ">=1.1"
|
||||
python-jose = {version = ">=3.3", extras = ["cryptography"]}
|
||||
python-json-logger = ">=3.2.1"
|
||||
python-multipart = ">=0.0.22"
|
||||
python-multipart = "*"
|
||||
python-pptx = "*"
|
||||
python-socketio = "5.14"
|
||||
pythonnet = "*"
|
||||
@@ -6200,7 +6197,7 @@ setuptools = ">=78.1.1"
|
||||
shellingham = ">=1.5.4"
|
||||
sqlalchemy = {version = ">=2.0.40", extras = ["asyncio"]}
|
||||
sse-starlette = ">=3.0.2"
|
||||
starlette = ">=0.49.1"
|
||||
starlette = ">=0.48"
|
||||
tenacity = ">=8.5,<10"
|
||||
termcolor = "*"
|
||||
toml = "*"
|
||||
@@ -11964,7 +11961,7 @@ description = "Python for Window Extensions"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
|
||||
markers = "platform_system == \"Windows\" or sys_platform == \"win32\""
|
||||
files = [
|
||||
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
|
||||
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
|
||||
@@ -14912,4 +14909,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12,<3.14"
|
||||
content-hash = "ef037f6d6085d26166d35c56ce266439f8f1a4fea90bc43ccf15cfeaf116cae5"
|
||||
content-hash = "1cad6029269393af67155e930c72eae2c03da02e4b3a3699823f6168c14a4218"
|
||||
|
||||
@@ -49,7 +49,7 @@ prometheus-client = "^0.24.0"
|
||||
pandas = "^2.2.0"
|
||||
numpy = "^2.2.0"
|
||||
mcp = "^1.10.0"
|
||||
pillow = "^12.1.1"
|
||||
pillow = "^12.1.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.8.3"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from storage.blocked_email_domain_store import BlockedEmailDomainStore
|
||||
from storage.database import session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -22,7 +23,7 @@ class DomainBlocker:
|
||||
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
|
||||
return None
|
||||
|
||||
async def is_domain_blocked(self, email: str) -> bool:
|
||||
def is_domain_blocked(self, email: str) -> bool:
|
||||
"""Check if email domain is blocked by querying the database directly via SQL.
|
||||
|
||||
Supports blocking:
|
||||
@@ -44,7 +45,7 @@ class DomainBlocker:
|
||||
|
||||
try:
|
||||
# Query database directly via SQL to check if domain is blocked
|
||||
is_blocked = await self.store.is_domain_blocked(domain)
|
||||
is_blocked = self.store.is_domain_blocked(domain)
|
||||
|
||||
if is_blocked:
|
||||
logger.warning(f'Email domain {domain} is blocked for email: {email}')
|
||||
@@ -62,5 +63,5 @@ class DomainBlocker:
|
||||
|
||||
|
||||
# Initialize store and domain blocker
|
||||
_store = BlockedEmailDomainStore()
|
||||
_store = BlockedEmailDomainStore(session_maker=session_maker)
|
||||
domain_blocker = DomainBlocker(store=_store)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from integrations.github.github_service import SaaSGitHubService
|
||||
from pydantic import SecretStr
|
||||
from server.auth.auth_utils import user_verifier
|
||||
|
||||
from enterprise.server.auth.auth_utils import user_verifier
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_types import GitHubUser
|
||||
|
||||
|
||||
@@ -18,10 +18,9 @@ from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
from server.rate_limit import RateLimiter, create_redis_rate_limiter
|
||||
from sqlalchemy import delete, select
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.database import a_session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
@@ -125,7 +124,7 @@ class SaasUserAuth(UserAuth):
|
||||
if secrets_store:
|
||||
return secrets_store
|
||||
user_id = await self.get_user_id()
|
||||
secrets_store = SaasSecretsStore(user_id, get_config())
|
||||
secrets_store = SaasSecretsStore(user_id, session_maker, get_config())
|
||||
self.secrets_store = secrets_store
|
||||
return secrets_store
|
||||
|
||||
@@ -162,13 +161,12 @@ class SaasUserAuth(UserAuth):
|
||||
|
||||
try:
|
||||
# TODO: I think we can do this in a single request if we refactor
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(AuthTokens).where(
|
||||
AuthTokens.keycloak_user_id == self.user_id
|
||||
)
|
||||
with session_maker() as session:
|
||||
tokens = (
|
||||
session.query(AuthTokens)
|
||||
.where(AuthTokens.keycloak_user_id == self.user_id)
|
||||
.all()
|
||||
)
|
||||
tokens = result.scalars().all()
|
||||
|
||||
for token in tokens:
|
||||
idp_type = ProviderType(token.identity_provider)
|
||||
@@ -194,11 +192,11 @@ class SaasUserAuth(UserAuth):
|
||||
'idp_type': token.identity_provider,
|
||||
},
|
||||
)
|
||||
async with a_session_maker() as session:
|
||||
await session.execute(
|
||||
delete(AuthTokens).where(AuthTokens.id == token.id)
|
||||
)
|
||||
await session.commit()
|
||||
with session_maker() as session:
|
||||
session.query(AuthTokens).filter(
|
||||
AuthTokens.id == token.id
|
||||
).delete()
|
||||
session.commit()
|
||||
raise
|
||||
|
||||
self.provider_tokens = MappingProxyType(provider_tokens)
|
||||
@@ -212,7 +210,7 @@ class SaasUserAuth(UserAuth):
|
||||
if settings_store:
|
||||
return settings_store
|
||||
user_id = await self.get_user_id()
|
||||
settings_store = SaasSettingsStore(user_id, get_config())
|
||||
settings_store = SaasSettingsStore(user_id, session_maker, get_config())
|
||||
self.settings_store = settings_store
|
||||
return settings_store
|
||||
|
||||
@@ -280,7 +278,7 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
|
||||
return None
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
user_id = await api_key_store.validate_api_key(api_key)
|
||||
user_id = api_key_store.validate_api_key(api_key)
|
||||
if not user_id:
|
||||
return None
|
||||
offline_token = await token_manager.load_offline_token(user_id)
|
||||
@@ -329,7 +327,7 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
|
||||
email_verified = access_token_payload['email_verified']
|
||||
|
||||
# Check if email domain is blocked
|
||||
if email and await domain_blocker.is_domain_blocked(email):
|
||||
if email and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for existing user with email: {email}'
|
||||
)
|
||||
|
||||
@@ -38,9 +38,9 @@ from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
from sqlalchemy import String as SQLString
|
||||
from sqlalchemy import select, type_coerce
|
||||
from sqlalchemy import type_coerce
|
||||
from storage.auth_token_store import AuthTokenStore
|
||||
from storage.database import a_session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.offline_token_store import OfflineTokenStore
|
||||
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt
|
||||
@@ -783,24 +783,25 @@ class TokenManager:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def store_org_token(self, installation_id: int, installation_token: str):
|
||||
def store_org_token(self, installation_id: int, installation_token: str):
|
||||
"""Store a GitHub App installation token.
|
||||
|
||||
Args:
|
||||
installation_id: GitHub installation ID (integer or string)
|
||||
installation_token: The token to store
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
# Ensure installation_id is a string
|
||||
str_installation_id = str(installation_id)
|
||||
# Use type_coerce to ensure SQLAlchemy treats the parameter as a string
|
||||
result = await session.execute(
|
||||
select(GithubAppInstallation).filter(
|
||||
installation = (
|
||||
session.query(GithubAppInstallation)
|
||||
.filter(
|
||||
GithubAppInstallation.installation_id
|
||||
== type_coerce(str_installation_id, SQLString)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
installation = result.scalars().first()
|
||||
if installation:
|
||||
installation.encrypted_token = self.encrypt_text(installation_token)
|
||||
else:
|
||||
@@ -810,9 +811,9 @@ class TokenManager:
|
||||
encrypted_token=self.encrypt_text(installation_token),
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
async def load_org_token(self, installation_id: int) -> str | None:
|
||||
def load_org_token(self, installation_id: int) -> str | None:
|
||||
"""Load a GitHub App installation token.
|
||||
|
||||
Args:
|
||||
@@ -821,16 +822,17 @@ class TokenManager:
|
||||
Returns:
|
||||
The decrypted token if found, None otherwise
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
# Ensure installation_id is a string and use type_coerce
|
||||
str_installation_id = str(installation_id)
|
||||
result = await session.execute(
|
||||
select(GithubAppInstallation).filter(
|
||||
installation = (
|
||||
session.query(GithubAppInstallation)
|
||||
.filter(
|
||||
GithubAppInstallation.installation_id
|
||||
== type_coerce(str_installation_id, SQLString)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
installation = result.scalars().first()
|
||||
if not installation:
|
||||
return None
|
||||
token = self.decrypt_text(installation.encrypted_token)
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
||||
|
||||
from integrations.github.github_manager import GithubManager
|
||||
from integrations.github.github_view import GithubViewType
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.utils import (
|
||||
extract_summary_from_conversation_manager,
|
||||
get_summary_instruction,
|
||||
@@ -34,12 +35,16 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
send_summary_instruction: bool = True
|
||||
|
||||
async def _send_message_to_github(self, message: str) -> None:
|
||||
"""Send a message to GitHub.
|
||||
"""
|
||||
Send a message to GitHub.
|
||||
|
||||
Args:
|
||||
message: The message content to send to GitHub
|
||||
"""
|
||||
try:
|
||||
# Create a message object for GitHub
|
||||
message_obj = Message(source=SourceType.OPENHANDS, message=message)
|
||||
|
||||
# Get the token manager
|
||||
token_manager = TokenManager()
|
||||
|
||||
@@ -48,8 +53,8 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
|
||||
github_manager = GithubManager(token_manager, GitHubDataCollector())
|
||||
|
||||
# Send the message directly as a string
|
||||
await github_manager.send_message(message, self.github_view)
|
||||
# Send the message
|
||||
await github_manager.send_message(message_obj, self.github_view)
|
||||
|
||||
logger.info(
|
||||
f'[GitHub] Sent summary message to {self.github_view.full_repo_name}#{self.github_view.issue_number}'
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
||||
|
||||
from integrations.gitlab.gitlab_manager import GitlabManager
|
||||
from integrations.gitlab.gitlab_view import GitlabViewType
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.utils import (
|
||||
extract_summary_from_conversation_manager,
|
||||
get_summary_instruction,
|
||||
@@ -13,7 +14,7 @@ from storage.conversation_callback import (
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.database import a_session_maker
|
||||
from storage.database import session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
@@ -27,7 +28,8 @@ gitlab_manager = GitlabManager(token_manager)
|
||||
|
||||
|
||||
class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
"""Processor for sending conversation summaries to GitLab.
|
||||
"""
|
||||
Processor for sending conversation summaries to GitLab.
|
||||
|
||||
This processor is used to send summaries of conversations to GitLab
|
||||
when agent state changes occur.
|
||||
@@ -37,18 +39,22 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
send_summary_instruction: bool = True
|
||||
|
||||
async def _send_message_to_gitlab(self, message: str) -> None:
|
||||
"""Send a message to GitLab.
|
||||
"""
|
||||
Send a message to GitLab.
|
||||
|
||||
Args:
|
||||
message: The message content to send to GitLab
|
||||
"""
|
||||
try:
|
||||
# Create a message object for GitHub
|
||||
message_obj = Message(source=SourceType.OPENHANDS, message=message)
|
||||
|
||||
# Get the token manager
|
||||
token_manager = TokenManager()
|
||||
gitlab_manager = GitlabManager(token_manager)
|
||||
|
||||
# Send the message directly as a string
|
||||
await gitlab_manager.send_message(message, self.gitlab_view)
|
||||
# Send the message
|
||||
await gitlab_manager.send_message(message_obj, self.gitlab_view)
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Sent summary message to {self.gitlab_view.full_repo_name}#{self.gitlab_view.issue_number}'
|
||||
@@ -105,9 +111,9 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
self.send_summary_instruction = False
|
||||
callback.set_processor(self)
|
||||
callback.updated_at = datetime.now()
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
session.merge(callback)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
return
|
||||
|
||||
# Extract the summary from the event store
|
||||
@@ -126,9 +132,9 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
# Mark callback as completed status
|
||||
callback.status = CallbackStatus.COMPLETED
|
||||
callback.updated_at = datetime.now()
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
session.merge(callback)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
|
||||
@@ -37,7 +37,8 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace_name: str
|
||||
|
||||
async def _send_comment_to_jira(self, message: str) -> None:
|
||||
"""Send a comment to Jira issue.
|
||||
"""
|
||||
Send a comment to Jira issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Jira
|
||||
@@ -58,9 +59,8 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
|
||||
# Decrypt API key
|
||||
api_key = jira_manager.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
|
||||
# Send comment directly as a string
|
||||
await jira_manager.send_message(
|
||||
message,
|
||||
jira_manager.create_outgoing_message(msg=message),
|
||||
issue_key=self.issue_key,
|
||||
jira_cloud_id=workspace.jira_cloud_id,
|
||||
svc_acc_email=workspace.svc_acc_email,
|
||||
|
||||
@@ -37,7 +37,8 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
|
||||
base_api_url: str
|
||||
|
||||
async def _send_comment_to_jira_dc(self, message: str) -> None:
|
||||
"""Send a comment to Jira DC issue.
|
||||
"""
|
||||
Send a comment to Jira DC issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Jira DC
|
||||
@@ -60,9 +61,8 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
# Send comment directly as a string
|
||||
await jira_dc_manager.send_message(
|
||||
message,
|
||||
jira_dc_manager.create_outgoing_message(msg=message),
|
||||
issue_key=self.issue_key,
|
||||
base_api_url=self.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
|
||||
@@ -36,7 +36,8 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace_name: str
|
||||
|
||||
async def _send_comment_to_linear(self, message: str) -> None:
|
||||
"""Send a comment to Linear issue.
|
||||
"""
|
||||
Send a comment to Linear issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Linear
|
||||
@@ -59,9 +60,9 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
# Send comment directly as a string
|
||||
# Send comment
|
||||
await linear_manager.send_message(
|
||||
message,
|
||||
linear_manager.create_outgoing_message(msg=message),
|
||||
self.issue_id,
|
||||
api_key,
|
||||
)
|
||||
|
||||
@@ -26,7 +26,8 @@ slack_manager = SlackManager(token_manager)
|
||||
|
||||
|
||||
class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
"""Processor for sending conversation summaries to Slack.
|
||||
"""
|
||||
Processor for sending conversation summaries to Slack.
|
||||
|
||||
This processor is used to send summaries of conversations to Slack channels
|
||||
when agent state changes occur.
|
||||
@@ -40,13 +41,14 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
last_user_msg_id: int | None = None
|
||||
|
||||
async def _send_message_to_slack(self, message: str) -> None:
|
||||
"""Send a message to Slack.
|
||||
"""
|
||||
Send a message to Slack using the conversation_manager's send_to_event_stream method.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Slack
|
||||
"""
|
||||
try:
|
||||
# Create a message object for Slack view creation (incoming message format)
|
||||
# Create a message object for Slack
|
||||
message_obj = Message(
|
||||
source=SourceType.SLACK,
|
||||
message={
|
||||
@@ -65,8 +67,9 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
slack_view = SlackFactory.create_slack_view_from_payload(
|
||||
message_obj, slack_user, saas_user_auth
|
||||
)
|
||||
# Send the message directly as a string
|
||||
await slack_manager.send_message(message, slack_view)
|
||||
await slack_manager.send_message(
|
||||
slack_manager.create_outgoing_message(message), slack_view
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Slack] Sent summary message to channel {self.channel_id} '
|
||||
|
||||
@@ -251,7 +251,7 @@ async def delete_api_key(
|
||||
)
|
||||
|
||||
# Delete the key
|
||||
success = await api_key_store.delete_api_key_by_id(key_id)
|
||||
success = api_key_store.delete_api_key_by_id(key_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -34,8 +34,7 @@ from server.services.org_invitation_service import (
|
||||
OrgInvitationService,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
|
||||
@@ -271,7 +270,7 @@ async def keycloak_callback(
|
||||
# Fail open - continue with login if reCAPTCHA service unavailable
|
||||
|
||||
# Check if email domain is blocked
|
||||
if email and await domain_blocker.is_domain_blocked(email):
|
||||
if email and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
|
||||
)
|
||||
@@ -611,20 +610,17 @@ async def accept_tos(request: Request):
|
||||
|
||||
# Update user settings with TOS acceptance
|
||||
accepted_tos: datetime = datetime.now(timezone.utc)
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
with session_maker() as session:
|
||||
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
|
||||
if not user:
|
||||
await session.rollback()
|
||||
session.rollback()
|
||||
logger.error('User for {user_id} not found.')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'User does not exist'},
|
||||
)
|
||||
user.accepted_tos = accepted_tos
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
|
||||
|
||||
@@ -11,10 +11,9 @@ from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import a_session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
@@ -107,17 +106,16 @@ async def get_subscription_access(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> SubscriptionAccessResponse | None:
|
||||
"""Get details of the currently valid subscription for the user."""
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
result = await session.execute(
|
||||
select(SubscriptionAccess).where(
|
||||
SubscriptionAccess.status == 'ACTIVE',
|
||||
SubscriptionAccess.user_id == user_id,
|
||||
SubscriptionAccess.start_at <= now,
|
||||
SubscriptionAccess.end_at >= now,
|
||||
)
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.first()
|
||||
)
|
||||
subscription_access = result.scalar_one_or_none()
|
||||
if not subscription_access:
|
||||
return None
|
||||
return SubscriptionAccessResponse(
|
||||
@@ -199,7 +197,7 @@ async def create_checkout_session(
|
||||
'checkout_session_id': checkout_session.id,
|
||||
},
|
||||
)
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
@@ -208,7 +206,7 @@ async def create_checkout_session(
|
||||
price_code='NA',
|
||||
)
|
||||
session.add(billing_session)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
@@ -217,14 +215,13 @@ async def create_checkout_session(
|
||||
@billing_router.get('/success')
|
||||
async def success_callback(session_id: str, request: Request):
|
||||
# We can't use the auth cookie because of SameSite=strict
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(
|
||||
BillingSession.id == session_id,
|
||||
BillingSession.status == 'in_progress',
|
||||
)
|
||||
with session_maker() as session:
|
||||
billing_session = (
|
||||
session.query(BillingSession)
|
||||
.filter(BillingSession.id == session_id)
|
||||
.filter(BillingSession.status == 'in_progress')
|
||||
.first()
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
|
||||
if billing_session is None:
|
||||
# Hopefully this never happens - we get a redirect from stripe where the session does not exist
|
||||
@@ -256,8 +253,7 @@ async def success_callback(session_id: str, request: Request):
|
||||
user_team_info, billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
|
||||
result = await session.execute(select(Org).where(Org.id == user.current_org_id))
|
||||
org = result.scalar_one_or_none()
|
||||
org = session.query(Org).filter(Org.id == user.current_org_id).first()
|
||||
new_max_budget = max_budget + add_credits
|
||||
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
@@ -283,7 +279,7 @@ async def success_callback(session_id: str, request: Request):
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
|
||||
@@ -293,14 +289,13 @@ async def success_callback(session_id: str, request: Request):
|
||||
# Callback endpoint for cancelled Stripe payments - updates billing session status
|
||||
@billing_router.get('/cancel')
|
||||
async def cancel_callback(session_id: str, request: Request):
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(
|
||||
BillingSession.id == session_id,
|
||||
BillingSession.status == 'in_progress',
|
||||
)
|
||||
with session_maker() as session:
|
||||
billing_session = (
|
||||
session.query(BillingSession)
|
||||
.filter(BillingSession.id == session_id)
|
||||
.filter(BillingSession.status == 'in_progress')
|
||||
.first()
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
if billing_session:
|
||||
logger.info(
|
||||
'stripe_checkout_cancel',
|
||||
@@ -312,7 +307,7 @@ async def cancel_callback(session_id: str, request: Request):
|
||||
billing_session.status = 'cancelled'
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.future import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.feedback import ConversationFeedback
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
@@ -11,6 +11,7 @@ from openhands.events.event_store import EventStore
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.shared import file_store
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint
|
||||
# is protected. The actual protection is provided by SetAuthCookieMiddleware
|
||||
@@ -36,19 +37,23 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
|
||||
"""
|
||||
|
||||
# Verify the conversation belongs to the user
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
)
|
||||
)
|
||||
metadata = result.scalars().first()
|
||||
if not metadata:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Conversation {conversation_id} not found',
|
||||
def _verify_conversation():
|
||||
with session_maker() as session:
|
||||
metadata = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not metadata:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Conversation {conversation_id} not found',
|
||||
)
|
||||
|
||||
await call_sync_from_async(_verify_conversation)
|
||||
|
||||
# Create an event store to access the events directly
|
||||
# This works even when the conversation is not running
|
||||
@@ -98,9 +103,12 @@ async def submit_conversation_feedback(feedback: FeedbackRequest):
|
||||
)
|
||||
|
||||
# Add to database
|
||||
async with a_session_maker() as session:
|
||||
session.add(new_feedback)
|
||||
await session.commit()
|
||||
def _save_feedback():
|
||||
with session_maker() as session:
|
||||
session.add(new_feedback)
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_save_feedback)
|
||||
|
||||
return {'status': 'success', 'message': 'Feedback submitted successfully'}
|
||||
|
||||
@@ -119,27 +127,30 @@ async def get_batch_feedback(conversation_id: str, user_id: str = Depends(get_us
|
||||
return {}
|
||||
|
||||
# Query for existing feedback for all events
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ConversationFeedback).where(
|
||||
ConversationFeedback.conversation_id == conversation_id,
|
||||
ConversationFeedback.event_id.in_(event_ids),
|
||||
def _check_feedback():
|
||||
with session_maker() as session:
|
||||
result = session.execute(
|
||||
select(ConversationFeedback).where(
|
||||
ConversationFeedback.conversation_id == conversation_id,
|
||||
ConversationFeedback.event_id.in_(event_ids),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Create a mapping of event_id to feedback
|
||||
feedback_map = {
|
||||
feedback.event_id: {
|
||||
'exists': True,
|
||||
'rating': feedback.rating,
|
||||
'reason': feedback.reason,
|
||||
# Create a mapping of event_id to feedback
|
||||
feedback_map = {
|
||||
feedback.event_id: {
|
||||
'exists': True,
|
||||
'rating': feedback.rating,
|
||||
'reason': feedback.reason,
|
||||
}
|
||||
for feedback in result.scalars()
|
||||
}
|
||||
for feedback in result.scalars()
|
||||
}
|
||||
|
||||
# Build response including all events
|
||||
response = {}
|
||||
for event_id in event_ids:
|
||||
response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
|
||||
# Build response including all events
|
||||
response = {}
|
||||
for event_id in event_ids:
|
||||
response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
|
||||
|
||||
return response
|
||||
return response
|
||||
|
||||
return await call_sync_from_async(_check_feedback)
|
||||
|
||||
@@ -308,11 +308,10 @@ async def jira_events(
|
||||
logger.info(f'Processing new Jira webhook event: {signature}')
|
||||
redis_client.setex(key, 300, '1')
|
||||
|
||||
# Process the webhook in background after returning response.
|
||||
# Note: For async functions, BackgroundTasks runs them in the same event loop
|
||||
# (not a thread pool), so asyncpg connections work correctly.
|
||||
# Process the webhook
|
||||
message_payload = {'payload': payload}
|
||||
message = Message(source=SourceType.JIRA, message=message_payload)
|
||||
|
||||
background_tasks.add_task(jira_manager.receive_message, message)
|
||||
|
||||
return JSONResponse({'success': True})
|
||||
|
||||
@@ -7,6 +7,7 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.device_code_store import DeviceCodeStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -53,7 +54,7 @@ class DeviceTokenErrorResponse(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
oauth_device_router = APIRouter(prefix='/oauth/device')
|
||||
device_code_store = DeviceCodeStore()
|
||||
device_code_store = DeviceCodeStore(session_maker)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -89,7 +90,7 @@ async def device_authorization(
|
||||
) -> DeviceAuthorizationResponse:
|
||||
"""Start device flow by generating device and user codes."""
|
||||
try:
|
||||
device_code_entry = await device_code_store.create_device_code(
|
||||
device_code_entry = device_code_store.create_device_code(
|
||||
expires_in=DEVICE_CODE_EXPIRES_IN,
|
||||
)
|
||||
|
||||
@@ -124,7 +125,7 @@ async def device_authorization(
|
||||
async def device_token(device_code: str = Form(...)):
|
||||
"""Poll for a token until the user authorizes or the code expires."""
|
||||
try:
|
||||
device_code_entry = await device_code_store.get_by_device_code(device_code)
|
||||
device_code_entry = device_code_store.get_by_device_code(device_code)
|
||||
|
||||
if not device_code_entry:
|
||||
return _oauth_error(
|
||||
@@ -137,9 +138,7 @@ async def device_token(device_code: str = Form(...)):
|
||||
is_too_fast, current_interval = device_code_entry.check_rate_limit()
|
||||
if is_too_fast:
|
||||
# Update poll time and increase interval
|
||||
await device_code_store.update_poll_time(
|
||||
device_code, increase_interval=True
|
||||
)
|
||||
device_code_store.update_poll_time(device_code, increase_interval=True)
|
||||
logger.warning(
|
||||
'Client polling too fast, returning slow_down error',
|
||||
extra={
|
||||
@@ -155,7 +154,7 @@ async def device_token(device_code: str = Form(...)):
|
||||
)
|
||||
|
||||
# Update poll time for successful rate limit check
|
||||
await device_code_store.update_poll_time(device_code, increase_interval=False)
|
||||
device_code_store.update_poll_time(device_code, increase_interval=False)
|
||||
|
||||
if device_code_entry.is_expired():
|
||||
return _oauth_error(
|
||||
@@ -182,7 +181,7 @@ async def device_token(device_code: str = Form(...)):
|
||||
# Retrieve the specific API key for this device using the user_code
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
device_key_name = f'{API_KEY_NAME} ({device_code_entry.user_code})'
|
||||
device_api_key = await api_key_store.retrieve_api_key_by_name(
|
||||
device_api_key = api_key_store.retrieve_api_key_by_name(
|
||||
device_code_entry.keycloak_user_id, device_key_name
|
||||
)
|
||||
|
||||
@@ -239,7 +238,7 @@ async def device_verification_authenticated(
|
||||
)
|
||||
|
||||
# Validate device code
|
||||
device_code_entry = await device_code_store.get_by_user_code(user_code)
|
||||
device_code_entry = device_code_store.get_by_user_code(user_code)
|
||||
if not device_code_entry:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -253,7 +252,7 @@ async def device_verification_authenticated(
|
||||
)
|
||||
|
||||
# First, authorize the device code
|
||||
success = await device_code_store.authorize_device_code(
|
||||
success = device_code_store.authorize_device_code(
|
||||
user_code=user_code,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -290,7 +289,7 @@ async def device_verification_authenticated(
|
||||
# Clean up: revert the device authorization since API key creation failed
|
||||
# This prevents the device from being in an authorized state without an API key
|
||||
try:
|
||||
await device_code_store.deny_device_code(user_code)
|
||||
device_code_store.deny_device_code(user_code)
|
||||
logger.info(
|
||||
'Reverted device authorization due to API key creation failure',
|
||||
extra={'user_code': user_code, 'user_id': user_id},
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from sqlalchemy.sql import text
|
||||
from storage.database import a_session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.redis import create_redis_client
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -9,11 +9,11 @@ readiness_router = APIRouter()
|
||||
|
||||
|
||||
@readiness_router.get('/ready')
|
||||
async def is_ready():
|
||||
def is_ready():
|
||||
# Check database connection
|
||||
try:
|
||||
async with a_session_maker() as session:
|
||||
await session.execute(text('SELECT 1'))
|
||||
with session_maker() as session:
|
||||
session.execute(text('SELECT 1'))
|
||||
except Exception as e:
|
||||
logger.error(f'Database check failed: {str(e)}')
|
||||
raise HTTPException(
|
||||
|
||||
@@ -388,4 +388,5 @@ async def _check_idp(
|
||||
access_token.get_secret_value(), ProviderType(idp)
|
||||
):
|
||||
return default_value
|
||||
|
||||
return None
|
||||
|
||||
@@ -4,14 +4,13 @@ import pickle
|
||||
from datetime import datetime
|
||||
|
||||
from server.logger import logger
|
||||
from sqlalchemy import and_, select
|
||||
from storage.conversation_callback import (
|
||||
CallbackStatus,
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
@@ -80,16 +79,15 @@ async def invoke_conversation_callbacks(
|
||||
conversation_id: The conversation ID to process callbacks for
|
||||
observation: The AgentStateChangedObservation that triggered the callback
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ConversationCallback).filter(
|
||||
and_(
|
||||
ConversationCallback.conversation_id == conversation_id,
|
||||
ConversationCallback.status == CallbackStatus.ACTIVE,
|
||||
)
|
||||
with session_maker() as session:
|
||||
callbacks = (
|
||||
session.query(ConversationCallback)
|
||||
.filter(
|
||||
ConversationCallback.conversation_id == conversation_id,
|
||||
ConversationCallback.status == CallbackStatus.ACTIVE,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
callbacks = result.scalars().all()
|
||||
|
||||
for callback in callbacks:
|
||||
try:
|
||||
@@ -117,7 +115,7 @@ async def invoke_conversation_callbacks(
|
||||
callback.status = CallbackStatus.ERROR
|
||||
callback.updated_at = datetime.now()
|
||||
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
|
||||
def update_conversation_metadata(conversation_id: str, content: dict):
|
||||
|
||||
@@ -2,10 +2,6 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from server.verified_models.verified_model_models import (
|
||||
VerifiedModel,
|
||||
VerifiedModelPage,
|
||||
)
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
@@ -22,6 +18,10 @@ from sqlalchemy import (
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.base import Base
|
||||
|
||||
from enterprise.server.verified_models.verified_model_models import (
|
||||
VerifiedModel,
|
||||
VerifiedModelPage,
|
||||
)
|
||||
from openhands.app_server.config import depends_db_session
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
@@ -5,16 +5,20 @@ import string
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.api_key import ApiKey
|
||||
from storage.database import a_session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiKeyStore:
|
||||
session_maker: sessionmaker
|
||||
|
||||
API_KEY_PREFIX = 'sk-oh-'
|
||||
|
||||
def generate_api_key(self, length: int = 32) -> str:
|
||||
@@ -39,8 +43,22 @@ class ApiKeyStore:
|
||||
api_key = self.generate_api_key()
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id
|
||||
await call_sync_from_async(
|
||||
self._store_api_key, user_id, org_id, api_key, name, expires_at
|
||||
)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
return api_key
|
||||
|
||||
def _store_api_key(
|
||||
self,
|
||||
user_id: str,
|
||||
org_id: str,
|
||||
api_key: str,
|
||||
name: str | None,
|
||||
expires_at: datetime | None = None,
|
||||
) -> None:
|
||||
"""Store an existing API key in the database."""
|
||||
with self.session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key,
|
||||
user_id=user_id,
|
||||
@@ -49,17 +67,14 @@ class ApiKeyStore:
|
||||
expires_at=expires_at,
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
return api_key
|
||||
|
||||
async def validate_api_key(self, api_key: str) -> str | None:
|
||||
def validate_api_key(self, api_key: str) -> str | None:
|
||||
"""Validate an API key and return the associated user_id if valid."""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key))
|
||||
key_record = result.scalars().first()
|
||||
with self.session_maker() as session:
|
||||
key_record = session.query(ApiKey).filter(ApiKey.key == api_key).first()
|
||||
|
||||
if not key_record:
|
||||
return None
|
||||
@@ -76,40 +91,38 @@ class ApiKeyStore:
|
||||
return None
|
||||
|
||||
# Update last_used_at timestamp
|
||||
await session.execute(
|
||||
session.execute(
|
||||
update(ApiKey)
|
||||
.where(ApiKey.id == key_record.id)
|
||||
.values(last_used_at=now)
|
||||
)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
return key_record.user_id
|
||||
|
||||
async def delete_api_key(self, api_key: str) -> bool:
|
||||
def delete_api_key(self, api_key: str) -> bool:
|
||||
"""Delete an API key by the key value."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key))
|
||||
key_record = result.scalars().first()
|
||||
with self.session_maker() as session:
|
||||
key_record = session.query(ApiKey).filter(ApiKey.key == api_key).first()
|
||||
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
await session.delete(key_record)
|
||||
await session.commit()
|
||||
session.delete(key_record)
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
async def delete_api_key_by_id(self, key_id: int) -> bool:
|
||||
def delete_api_key_by_id(self, key_id: int) -> bool:
|
||||
"""Delete an API key by its ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
key_record = result.scalars().first()
|
||||
with self.session_maker() as session:
|
||||
key_record = session.query(ApiKey).filter(ApiKey.id == key_id).first()
|
||||
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
await session.delete(key_record)
|
||||
await session.commit()
|
||||
session.delete(key_record)
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
@@ -117,55 +130,64 @@ class ApiKeyStore:
|
||||
"""List all API keys for a user."""
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id
|
||||
return await call_sync_from_async(self._list_api_keys_from_db, user_id, org_id)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(
|
||||
ApiKey.user_id == user_id, ApiKey.org_id == org_id
|
||||
)
|
||||
def _list_api_keys_from_db(self, user_id: str, org_id: str) -> list[ApiKey]:
|
||||
with self.session_maker() as session:
|
||||
keys: list[ApiKey] = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
|
||||
return [key for key in keys if key.name != 'MCP_API_KEY']
|
||||
|
||||
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id
|
||||
return await call_sync_from_async(
|
||||
self._retrieve_mcp_api_key_from_db, user_id, org_id
|
||||
)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(
|
||||
ApiKey.user_id == user_id, ApiKey.org_id == org_id
|
||||
)
|
||||
def _retrieve_mcp_api_key_from_db(self, user_id: str, org_id: str) -> str | None:
|
||||
with self.session_maker() as session:
|
||||
keys: list[ApiKey] = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
for key in keys:
|
||||
if key.name == 'MCP_API_KEY':
|
||||
return key.key
|
||||
|
||||
return None
|
||||
|
||||
async def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None:
|
||||
def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None:
|
||||
"""Retrieve an API key by name for a specific user."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
with self.session_maker() as session:
|
||||
key_record = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
.first()
|
||||
)
|
||||
key_record = result.scalars().first()
|
||||
return key_record.key if key_record else None
|
||||
|
||||
async def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
|
||||
def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
|
||||
"""Delete an API key by name for a specific user."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
with self.session_maker() as session:
|
||||
key_record = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
.first()
|
||||
)
|
||||
key_record = result.scalars().first()
|
||||
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
await session.delete(key_record)
|
||||
await session.commit()
|
||||
session.delete(key_record)
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
@@ -173,4 +195,4 @@ class ApiKeyStore:
|
||||
def get_instance(cls) -> ApiKeyStore:
|
||||
"""Get an instance of the ApiKeyStore."""
|
||||
logger.debug('api_key_store.get_instance')
|
||||
return ApiKeyStore()
|
||||
return ApiKeyStore(session_maker)
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Awaitable, Callable, Dict
|
||||
from server.auth.auth_error import TokenRefreshError
|
||||
from sqlalchemy import select, text, update
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.database import a_session_maker
|
||||
|
||||
@@ -26,6 +27,7 @@ LOCK_TIMEOUT_SECONDS = 5
|
||||
class AuthTokenStore:
|
||||
keycloak_user_id: str
|
||||
idp: ProviderType
|
||||
a_session_maker: sessionmaker
|
||||
|
||||
@property
|
||||
def identity_provider_value(self) -> str:
|
||||
@@ -71,7 +73,7 @@ class AuthTokenStore:
|
||||
access_token_expires_at: Expiration time for access token (seconds since epoch)
|
||||
refresh_token_expires_at: Expiration time for refresh token (seconds since epoch)
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
async with self.a_session_maker() as session:
|
||||
async with session.begin(): # Explicitly start a transaction
|
||||
result = await session.execute(
|
||||
select(AuthTokens).where(
|
||||
@@ -136,7 +138,7 @@ class AuthTokenStore:
|
||||
a 401 response to prompt the user to re-authenticate.
|
||||
"""
|
||||
# FAST PATH: Check without lock first to avoid unnecessary lock contention
|
||||
async with a_session_maker() as session:
|
||||
async with self.a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(AuthTokens).filter(
|
||||
AuthTokens.keycloak_user_id == self.keycloak_user_id,
|
||||
@@ -165,7 +167,7 @@ class AuthTokenStore:
|
||||
|
||||
# SLOW PATH: Token needs refresh, acquire lock
|
||||
try:
|
||||
async with a_session_maker() as session:
|
||||
async with self.a_session_maker() as session:
|
||||
async with session.begin():
|
||||
# Set a lock timeout to prevent indefinite blocking
|
||||
# This ensures we don't hold connections forever if something goes wrong
|
||||
@@ -298,4 +300,6 @@ class AuthTokenStore:
|
||||
logger.debug(f'auth_token_store.get_instance::{keycloak_user_id}')
|
||||
if keycloak_user_id:
|
||||
keycloak_user_id = str(keycloak_user_id)
|
||||
return AuthTokenStore(keycloak_user_id=keycloak_user_id, idp=idp)
|
||||
return AuthTokenStore(
|
||||
keycloak_user_id=keycloak_user_id, idp=idp, a_session_maker=a_session_maker
|
||||
)
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import text
|
||||
from storage.database import a_session_maker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockedEmailDomainStore:
|
||||
async def is_domain_blocked(self, domain: str) -> bool:
|
||||
session_maker: sessionmaker
|
||||
|
||||
def is_domain_blocked(self, domain: str) -> bool:
|
||||
"""Check if a domain is blocked by querying the database directly.
|
||||
|
||||
This method uses SQL to efficiently check if the domain matches any blocked pattern:
|
||||
@@ -19,9 +21,9 @@ class BlockedEmailDomainStore:
|
||||
Returns:
|
||||
True if the domain is blocked, False otherwise
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
with self.session_maker() as session:
|
||||
# SQL query that handles both TLD patterns and full domain patterns
|
||||
# TLD patterns (starting with '.'): check if domain ends with it (case-insensitive)
|
||||
# TLD patterns (starting with '.'): check if domain ends with the pattern
|
||||
# Full domain patterns: check for exact match or subdomain match
|
||||
# All comparisons are case-insensitive using LOWER() to ensure consistent matching
|
||||
query = text("""
|
||||
@@ -39,5 +41,5 @@ class BlockedEmailDomainStore:
|
||||
))
|
||||
)
|
||||
""")
|
||||
result = await session.execute(query, {'domain': domain})
|
||||
return bool(result.scalar())
|
||||
result = session.execute(query, {'domain': domain}).scalar()
|
||||
return bool(result)
|
||||
|
||||
@@ -47,11 +47,7 @@ class DeviceCode(Base):
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the device code has expired."""
|
||||
now = datetime.now(timezone.utc)
|
||||
# Handle timezone-naive datetime from database by assuming it's UTC
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return now > expires_at
|
||||
return now > self.expires_at
|
||||
|
||||
def is_pending(self) -> bool:
|
||||
"""Check if the device code is still pending authorization."""
|
||||
@@ -89,13 +85,8 @@ class DeviceCode(Base):
|
||||
if self.last_poll_time is None:
|
||||
return False, self.current_interval
|
||||
|
||||
# Handle timezone-naive datetime from database by assuming it's UTC
|
||||
last_poll_time = self.last_poll_time
|
||||
if last_poll_time.tzinfo is None:
|
||||
last_poll_time = last_poll_time.replace(tzinfo=timezone.utc)
|
||||
|
||||
# Calculate time since last poll
|
||||
time_since_last_poll = (now - last_poll_time).total_seconds()
|
||||
time_since_last_poll = (now - self.last_poll_time).total_seconds()
|
||||
|
||||
# Check if polling too fast
|
||||
if time_since_last_poll < self.current_interval:
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
"""Device code store for OAuth 2.0 Device Flow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from storage.database import a_session_maker
|
||||
from storage.device_code import DeviceCode
|
||||
|
||||
|
||||
class DeviceCodeStore:
|
||||
"""Store for managing OAuth 2.0 device codes."""
|
||||
|
||||
def __init__(self, session_maker):
|
||||
self.session_maker = session_maker
|
||||
|
||||
def generate_user_code(self) -> str:
|
||||
"""Generate a human-readable user code (8 characters, uppercase letters and digits)."""
|
||||
# Use a mix of uppercase letters and digits, avoiding confusing characters
|
||||
@@ -26,7 +25,7 @@ class DeviceCodeStore:
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
return ''.join(secrets.choice(alphabet) for _ in range(128))
|
||||
|
||||
async def create_device_code(
|
||||
def create_device_code(
|
||||
self,
|
||||
expires_in: int = 600, # 10 minutes default
|
||||
max_attempts: int = 10,
|
||||
@@ -59,10 +58,11 @@ class DeviceCodeStore:
|
||||
)
|
||||
|
||||
try:
|
||||
async with a_session_maker() as session:
|
||||
with self.session_maker() as session:
|
||||
session.add(device_code_entry)
|
||||
await session.commit()
|
||||
await session.refresh(device_code_entry)
|
||||
session.commit()
|
||||
session.refresh(device_code_entry)
|
||||
session.expunge(device_code_entry) # Detach from session cleanly
|
||||
return device_code_entry
|
||||
except IntegrityError:
|
||||
# Constraint violation - codes already exist, retry with new codes
|
||||
@@ -72,23 +72,25 @@ class DeviceCodeStore:
|
||||
f'Failed to generate unique device codes after {max_attempts} attempts'
|
||||
)
|
||||
|
||||
async def get_by_device_code(self, device_code: str) -> DeviceCode | None:
|
||||
def get_by_device_code(self, device_code: str) -> DeviceCode | None:
|
||||
"""Get device code entry by device code."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(DeviceCode).filter_by(device_code=device_code)
|
||||
with self.session_maker() as session:
|
||||
result = (
|
||||
session.query(DeviceCode).filter_by(device_code=device_code).first()
|
||||
)
|
||||
return result.scalars().first()
|
||||
if result:
|
||||
session.expunge(result) # Detach from session cleanly
|
||||
return result
|
||||
|
||||
async def get_by_user_code(self, user_code: str) -> DeviceCode | None:
|
||||
def get_by_user_code(self, user_code: str) -> DeviceCode | None:
|
||||
"""Get device code entry by user code."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(DeviceCode).filter_by(user_code=user_code)
|
||||
)
|
||||
return result.scalars().first()
|
||||
with self.session_maker() as session:
|
||||
result = session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
if result:
|
||||
session.expunge(result) # Detach from session cleanly
|
||||
return result
|
||||
|
||||
async def authorize_device_code(self, user_code: str, user_id: str) -> bool:
|
||||
def authorize_device_code(self, user_code: str, user_id: str) -> bool:
|
||||
"""Authorize a device code.
|
||||
|
||||
Args:
|
||||
@@ -98,11 +100,10 @@ class DeviceCodeStore:
|
||||
Returns:
|
||||
True if authorization was successful, False otherwise
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(DeviceCode).filter_by(user_code=user_code)
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
)
|
||||
device_code_entry = result.scalars().first()
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
@@ -111,11 +112,11 @@ class DeviceCodeStore:
|
||||
return False
|
||||
|
||||
device_code_entry.authorize(user_id)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
async def deny_device_code(self, user_code: str) -> bool:
|
||||
def deny_device_code(self, user_code: str) -> bool:
|
||||
"""Deny a device code authorization.
|
||||
|
||||
Args:
|
||||
@@ -124,11 +125,10 @@ class DeviceCodeStore:
|
||||
Returns:
|
||||
True if denial was successful, False otherwise
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(DeviceCode).filter_by(user_code=user_code)
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
)
|
||||
device_code_entry = result.scalars().first()
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
@@ -137,11 +137,11 @@ class DeviceCodeStore:
|
||||
return False
|
||||
|
||||
device_code_entry.deny()
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
async def update_poll_time(
|
||||
def update_poll_time(
|
||||
self, device_code: str, increase_interval: bool = False
|
||||
) -> bool:
|
||||
"""Update the poll time for a device code and optionally increase interval.
|
||||
@@ -153,16 +153,15 @@ class DeviceCodeStore:
|
||||
Returns:
|
||||
True if update was successful, False otherwise
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(DeviceCode).filter_by(device_code=device_code)
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(device_code=device_code).first()
|
||||
)
|
||||
device_code_entry = result.scalars().first()
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
|
||||
device_code_entry.update_poll_time(increase_interval)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
@@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
from integrations.types import GitLabResourceType
|
||||
from sqlalchemy import and_, asc, select, text, update
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import a_session_maker
|
||||
from storage.gitlab_webhook import GitlabWebhook
|
||||
|
||||
@@ -13,6 +14,8 @@ from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@dataclass
|
||||
class GitlabWebhookStore:
|
||||
a_session_maker: sessionmaker = a_session_maker
|
||||
|
||||
@staticmethod
|
||||
def determine_resource_type(
|
||||
webhook: GitlabWebhook,
|
||||
@@ -41,7 +44,7 @@ class GitlabWebhookStore:
|
||||
if not project_details:
|
||||
return
|
||||
|
||||
async with a_session_maker() as session:
|
||||
async with self.a_session_maker() as session:
|
||||
async with session.begin():
|
||||
# Convert GitlabWebhook objects to dictionaries for the insert
|
||||
# Using __dict__ and filtering out SQLAlchemy internal attributes and 'id'
|
||||
@@ -85,7 +88,7 @@ class GitlabWebhookStore:
|
||||
"""
|
||||
|
||||
resource_type, resource_id = GitlabWebhookStore.determine_resource_type(webhook)
|
||||
async with a_session_maker() as session:
|
||||
async with self.a_session_maker() as session:
|
||||
async with session.begin():
|
||||
stmt = (
|
||||
update(GitlabWebhook).where(GitlabWebhook.project_id == resource_id)
|
||||
@@ -119,7 +122,7 @@ class GitlabWebhookStore:
|
||||
},
|
||||
)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
async with self.a_session_maker() as session:
|
||||
async with session.begin():
|
||||
# Create query based on the identifier provided
|
||||
if resource_type == GitLabResourceType.PROJECT:
|
||||
@@ -182,7 +185,7 @@ class GitlabWebhookStore:
|
||||
List of GitlabWebhook objects that need processing
|
||||
"""
|
||||
|
||||
async with a_session_maker() as session:
|
||||
async with self.a_session_maker() as session:
|
||||
query = (
|
||||
select(GitlabWebhook)
|
||||
.where(GitlabWebhook.webhook_exists.is_(False))
|
||||
@@ -198,7 +201,7 @@ class GitlabWebhookStore:
|
||||
"""
|
||||
Get's webhook secret given the webhook uuid and admin keycloak user id
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
async with self.a_session_maker() as session:
|
||||
query = (
|
||||
select(GitlabWebhook)
|
||||
.where(
|
||||
@@ -232,7 +235,7 @@ class GitlabWebhookStore:
|
||||
Returns:
|
||||
GitlabWebhook object if found, None otherwise
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
async with self.a_session_maker() as session:
|
||||
if resource_type == GitLabResourceType.PROJECT:
|
||||
query = select(GitlabWebhook).where(
|
||||
GitlabWebhook.project_id == resource_id
|
||||
@@ -260,7 +263,7 @@ class GitlabWebhookStore:
|
||||
Returns:
|
||||
Tuple of (project_webhook_map, group_webhook_map)
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
async with self.a_session_maker() as session:
|
||||
project_webhook_map = {}
|
||||
group_webhook_map = {}
|
||||
|
||||
@@ -300,7 +303,7 @@ class GitlabWebhookStore:
|
||||
Returns:
|
||||
True if webhook was reset, False if not found
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
async with self.a_session_maker() as session:
|
||||
async with session.begin():
|
||||
if resource_type == GitLabResourceType.PROJECT:
|
||||
update_statement = (
|
||||
@@ -345,4 +348,4 @@ class GitlabWebhookStore:
|
||||
Returns:
|
||||
An instance of GitlabWebhookStore
|
||||
"""
|
||||
return GitlabWebhookStore()
|
||||
return GitlabWebhookStore(a_session_maker)
|
||||
|
||||
@@ -3,8 +3,7 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.jira_dc_conversation import JiraDcConversation
|
||||
from storage.jira_dc_user import JiraDcUser
|
||||
from storage.jira_dc_workspace import JiraDcWorkspace
|
||||
@@ -25,7 +24,7 @@ class JiraDcIntegrationStore:
|
||||
) -> JiraDcWorkspace:
|
||||
"""Create a new Jira DC workspace with encrypted sensitive data."""
|
||||
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
workspace = JiraDcWorkspace(
|
||||
name=name.lower(),
|
||||
admin_user_id=admin_user_id,
|
||||
@@ -35,8 +34,8 @@ class JiraDcIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
session.add(workspace)
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
logger.info(f'[Jira DC] Created workspace {workspace.name}')
|
||||
return workspace
|
||||
|
||||
@@ -49,12 +48,11 @@ class JiraDcIntegrationStore:
|
||||
status: Optional[str] = None,
|
||||
) -> JiraDcWorkspace:
|
||||
"""Update an existing Jira DC workspace with encrypted sensitive data."""
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
# Find existing workspace by ID
|
||||
result = await session.execute(
|
||||
select(JiraDcWorkspace).where(JiraDcWorkspace.id == id)
|
||||
workspace = (
|
||||
session.query(JiraDcWorkspace).filter(JiraDcWorkspace.id == id).first()
|
||||
)
|
||||
workspace = result.scalar_one_or_none()
|
||||
|
||||
if not workspace:
|
||||
raise ValueError(f'Workspace with ID "{id}" not found')
|
||||
@@ -71,8 +69,8 @@ class JiraDcIntegrationStore:
|
||||
if status is not None:
|
||||
workspace.status = status
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
|
||||
logger.info(f'[Jira DC] Updated workspace {workspace.name}')
|
||||
return workspace
|
||||
@@ -93,10 +91,10 @@ class JiraDcIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
session.add(jira_dc_user)
|
||||
await session.commit()
|
||||
await session.refresh(jira_dc_user)
|
||||
session.commit()
|
||||
session.refresh(jira_dc_user)
|
||||
|
||||
logger.info(
|
||||
f'[Jira DC] Created user {jira_dc_user.id} for workspace {jira_dc_workspace_id}'
|
||||
@@ -105,91 +103,94 @@ class JiraDcIntegrationStore:
|
||||
|
||||
async def get_workspace_by_id(self, workspace_id: int) -> Optional[JiraDcWorkspace]:
|
||||
"""Retrieve workspace by ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcWorkspace).where(JiraDcWorkspace.id == workspace_id)
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcWorkspace)
|
||||
.filter(JiraDcWorkspace.id == workspace_id)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_workspace_by_name(
|
||||
self, workspace_name: str
|
||||
) -> Optional[JiraDcWorkspace]:
|
||||
"""Retrieve workspace by name."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcWorkspace).where(
|
||||
JiraDcWorkspace.name == workspace_name.lower()
|
||||
)
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcWorkspace)
|
||||
.filter(JiraDcWorkspace.name == workspace_name.lower())
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_user_by_active_workspace(
|
||||
self, keycloak_user_id: str
|
||||
) -> Optional[JiraDcUser]:
|
||||
"""Retrieve user by Keycloak user ID."""
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
JiraDcUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraDcUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_user_by_keycloak_id_and_workspace(
|
||||
self, keycloak_user_id: str, jira_dc_workspace_id: int
|
||||
) -> Optional[JiraDcUser]:
|
||||
"""Get Jira DC user by Keycloak user ID and workspace ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
JiraDcUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_active_user(
|
||||
self, jira_dc_user_id: str, jira_dc_workspace_id: int
|
||||
) -> Optional[JiraDcUser]:
|
||||
"""Get Jira DC user by Keycloak user ID and workspace ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
JiraDcUser.jira_dc_user_id == jira_dc_user_id,
|
||||
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
|
||||
JiraDcUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_active_user_by_keycloak_id_and_workspace(
|
||||
self, keycloak_user_id: str, jira_dc_workspace_id: int
|
||||
) -> Optional[JiraDcUser]:
|
||||
"""Get Jira DC user by Keycloak user ID and workspace ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
JiraDcUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
|
||||
JiraDcUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_user_integration_status(
|
||||
self, keycloak_user_id: str, status: str
|
||||
) -> JiraDcUser:
|
||||
"""Update the status of a Jira DC user mapping."""
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
JiraDcUser.keycloak_user_id == keycloak_user_id
|
||||
)
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(JiraDcUser)
|
||||
.filter(JiraDcUser.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise ValueError(
|
||||
@@ -197,35 +198,37 @@ class JiraDcIntegrationStore:
|
||||
)
|
||||
|
||||
user.status = status
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
logger.info(f'[Jira DC] Updated user {keycloak_user_id} status to {status}')
|
||||
return user
|
||||
|
||||
async def deactivate_workspace(self, workspace_id: int):
|
||||
"""Deactivate the workspace and all user links for a given workspace."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
with session_maker() as session:
|
||||
users = (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
JiraDcUser.jira_dc_workspace_id == workspace_id,
|
||||
JiraDcUser.status == 'active',
|
||||
)
|
||||
.all()
|
||||
)
|
||||
users = result.scalars().all()
|
||||
|
||||
for user in users:
|
||||
user.status = 'inactive'
|
||||
session.add(user)
|
||||
|
||||
result = await session.execute(
|
||||
select(JiraDcWorkspace).where(JiraDcWorkspace.id == workspace_id)
|
||||
workspace = (
|
||||
session.query(JiraDcWorkspace)
|
||||
.filter(JiraDcWorkspace.id == workspace_id)
|
||||
.first()
|
||||
)
|
||||
workspace = result.scalar_one_or_none()
|
||||
if workspace:
|
||||
workspace.status = 'inactive'
|
||||
session.add(workspace)
|
||||
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
f'[Jira DC] Deactivated all user links for workspace {workspace_id}'
|
||||
@@ -235,22 +238,23 @@ class JiraDcIntegrationStore:
|
||||
self, jira_dc_conversation: JiraDcConversation
|
||||
) -> None:
|
||||
"""Create a new Jira DC conversation record."""
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
session.add(jira_dc_conversation)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
async def get_user_conversations_by_issue_id(
|
||||
self, issue_id: str, jira_dc_user_id: int
|
||||
) -> JiraDcConversation | None:
|
||||
"""Get a Jira DC conversation by issue ID and jira dc user ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcConversation).where(
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcConversation)
|
||||
.filter(
|
||||
JiraDcConversation.issue_id == issue_id,
|
||||
JiraDcConversation.jira_dc_user_id == jira_dc_user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> JiraDcIntegrationStore:
|
||||
|
||||
@@ -3,8 +3,7 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from storage.database import a_session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
@@ -36,10 +35,10 @@ class JiraIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
session.add(workspace)
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
|
||||
logger.info(f'[Jira] Created workspace {workspace.name}')
|
||||
return workspace
|
||||
@@ -54,12 +53,11 @@ class JiraIntegrationStore:
|
||||
status: Optional[str] = None,
|
||||
) -> JiraWorkspace:
|
||||
"""Update an existing Jira workspace with encrypted sensitive data."""
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
# Find existing workspace by ID
|
||||
result = await session.execute(
|
||||
select(JiraWorkspace).filter(JiraWorkspace.id == id)
|
||||
workspace = (
|
||||
session.query(JiraWorkspace).filter(JiraWorkspace.id == id).first()
|
||||
)
|
||||
workspace = result.scalars().first()
|
||||
|
||||
if not workspace:
|
||||
raise ValueError(f'Workspace with ID "{id}" not found')
|
||||
@@ -79,11 +77,11 @@ class JiraIntegrationStore:
|
||||
if status is not None:
|
||||
workspace.status = status
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
|
||||
logger.info(f'[Jira] Updated workspace {workspace.name}')
|
||||
return workspace
|
||||
logger.info(f'[Jira] Updated workspace {workspace.name}')
|
||||
return workspace
|
||||
|
||||
async def create_workspace_link(
|
||||
self,
|
||||
@@ -101,10 +99,10 @@ class JiraIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
session.add(jira_user)
|
||||
await session.commit()
|
||||
await session.refresh(jira_user)
|
||||
session.commit()
|
||||
session.refresh(jira_user)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Created user {jira_user.id} for workspace {jira_workspace_id}'
|
||||
@@ -113,77 +111,75 @@ class JiraIntegrationStore:
|
||||
|
||||
async def get_workspace_by_id(self, workspace_id: int) -> Optional[JiraWorkspace]:
|
||||
"""Retrieve workspace by ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraWorkspace).filter(JiraWorkspace.id == workspace_id)
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraWorkspace)
|
||||
.filter(JiraWorkspace.id == workspace_id)
|
||||
.first()
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_workspace_by_name(self, workspace_name: str) -> JiraWorkspace | None:
|
||||
"""Retrieve workspace by name."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraWorkspace).filter(
|
||||
JiraWorkspace.name == workspace_name.lower()
|
||||
)
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraWorkspace)
|
||||
.filter(JiraWorkspace.name == workspace_name.lower())
|
||||
.first()
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_user_by_active_workspace(
|
||||
self, keycloak_user_id: str
|
||||
) -> Optional[JiraUser]:
|
||||
"""Get Jira user by Keycloak user ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraUser).filter(
|
||||
and_(
|
||||
JiraUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraUser.status == 'active',
|
||||
)
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraUser)
|
||||
.filter(
|
||||
JiraUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_user_by_keycloak_id_and_workspace(
|
||||
self, keycloak_user_id: str, jira_workspace_id: int
|
||||
) -> Optional[JiraUser]:
|
||||
"""Get Jira user by Keycloak user ID and workspace ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraUser).filter(
|
||||
and_(
|
||||
JiraUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraUser.jira_workspace_id == jira_workspace_id,
|
||||
)
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraUser)
|
||||
.filter(
|
||||
JiraUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraUser.jira_workspace_id == jira_workspace_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_active_user(
|
||||
self, jira_user_id: str, jira_workspace_id: int
|
||||
) -> Optional[JiraUser]:
|
||||
"""Get Jira user by Keycloak user ID and workspace ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraUser).filter(
|
||||
and_(
|
||||
JiraUser.jira_user_id == jira_user_id,
|
||||
JiraUser.jira_workspace_id == jira_workspace_id,
|
||||
JiraUser.status == 'active',
|
||||
)
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraUser)
|
||||
.filter(
|
||||
JiraUser.jira_user_id == jira_user_id,
|
||||
JiraUser.jira_workspace_id == jira_workspace_id,
|
||||
JiraUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def update_user_integration_status(
|
||||
self, keycloak_user_id: str, status: str
|
||||
) -> JiraUser:
|
||||
"""Update Jira user integration status."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraUser).filter(JiraUser.keycloak_user_id == keycloak_user_id)
|
||||
with session_maker() as session:
|
||||
jira_user = (
|
||||
session.query(JiraUser)
|
||||
.filter(JiraUser.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
)
|
||||
jira_user = result.scalars().first()
|
||||
|
||||
if not jira_user:
|
||||
raise ValueError(
|
||||
@@ -191,61 +187,60 @@ class JiraIntegrationStore:
|
||||
)
|
||||
|
||||
jira_user.status = status
|
||||
await session.commit()
|
||||
await session.refresh(jira_user)
|
||||
session.commit()
|
||||
session.refresh(jira_user)
|
||||
|
||||
logger.info(f'[Jira] Updated user {keycloak_user_id} status to {status}')
|
||||
return jira_user
|
||||
|
||||
async def deactivate_workspace(self, workspace_id: int):
|
||||
"""Deactivate the workspace and all user links for a given workspace."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraUser).filter(
|
||||
and_(
|
||||
JiraUser.jira_workspace_id == workspace_id,
|
||||
JiraUser.status == 'active',
|
||||
)
|
||||
with session_maker() as session:
|
||||
users = (
|
||||
session.query(JiraUser)
|
||||
.filter(
|
||||
JiraUser.jira_workspace_id == workspace_id,
|
||||
JiraUser.status == 'active',
|
||||
)
|
||||
.all()
|
||||
)
|
||||
users = result.scalars().all()
|
||||
|
||||
for user in users:
|
||||
user.status = 'inactive'
|
||||
session.add(user)
|
||||
|
||||
result = await session.execute(
|
||||
select(JiraWorkspace).filter(JiraWorkspace.id == workspace_id)
|
||||
workspace = (
|
||||
session.query(JiraWorkspace)
|
||||
.filter(JiraWorkspace.id == workspace_id)
|
||||
.first()
|
||||
)
|
||||
workspace = result.scalars().first()
|
||||
if workspace:
|
||||
workspace.status = 'inactive'
|
||||
session.add(workspace)
|
||||
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
logger.info(f'[Jira] Deactivated all user links for workspace {workspace_id}')
|
||||
|
||||
async def create_conversation(self, jira_conversation: JiraConversation) -> None:
|
||||
"""Create a new Jira conversation record."""
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
session.add(jira_conversation)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
async def get_user_conversations_by_issue_id(
|
||||
self, issue_id: str, jira_user_id: int
|
||||
) -> JiraConversation | None:
|
||||
"""Get a Jira conversation by issue ID and jira user ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraConversation).filter(
|
||||
and_(
|
||||
JiraConversation.issue_id == issue_id,
|
||||
JiraConversation.jira_user_id == jira_user_id,
|
||||
)
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraConversation)
|
||||
.filter(
|
||||
JiraConversation.issue_id == issue_id,
|
||||
JiraConversation.jira_user_id == jira_user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> JiraIntegrationStore:
|
||||
|
||||
@@ -3,8 +3,7 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.linear_conversation import LinearConversation
|
||||
from storage.linear_user import LinearUser
|
||||
from storage.linear_workspace import LinearWorkspace
|
||||
@@ -36,10 +35,10 @@ class LinearIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
session.add(workspace)
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
|
||||
logger.info(f'[Linear] Created workspace {workspace.name}')
|
||||
return workspace
|
||||
@@ -54,12 +53,11 @@ class LinearIntegrationStore:
|
||||
status: Optional[str] = None,
|
||||
) -> LinearWorkspace:
|
||||
"""Update an existing Linear workspace with encrypted sensitive data."""
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
# Find existing workspace by ID
|
||||
result = await session.execute(
|
||||
select(LinearWorkspace).where(LinearWorkspace.id == id)
|
||||
workspace = (
|
||||
session.query(LinearWorkspace).filter(LinearWorkspace.id == id).first()
|
||||
)
|
||||
workspace = result.scalar_one_or_none()
|
||||
|
||||
if not workspace:
|
||||
raise ValueError(f'Workspace with ID "{id}" not found')
|
||||
@@ -79,8 +77,8 @@ class LinearIntegrationStore:
|
||||
if status is not None:
|
||||
workspace.status = status
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
|
||||
logger.info(f'[Linear] Updated workspace {workspace.name}')
|
||||
return workspace
|
||||
@@ -100,10 +98,10 @@ class LinearIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
session.add(linear_user)
|
||||
await session.commit()
|
||||
await session.refresh(linear_user)
|
||||
session.commit()
|
||||
session.refresh(linear_user)
|
||||
|
||||
logger.info(
|
||||
f'[Linear] Created user {linear_user.id} for workspace {linear_workspace_id}'
|
||||
@@ -112,75 +110,77 @@ class LinearIntegrationStore:
|
||||
|
||||
async def get_workspace_by_id(self, workspace_id: int) -> Optional[LinearWorkspace]:
|
||||
"""Retrieve workspace by ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearWorkspace).where(LinearWorkspace.id == workspace_id)
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearWorkspace)
|
||||
.filter(LinearWorkspace.id == workspace_id)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_workspace_by_name(
|
||||
self, workspace_name: str
|
||||
) -> Optional[LinearWorkspace]:
|
||||
"""Retrieve workspace by name."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearWorkspace).where(
|
||||
LinearWorkspace.name == workspace_name.lower()
|
||||
)
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearWorkspace)
|
||||
.filter(LinearWorkspace.name == workspace_name.lower())
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_user_by_active_workspace(
|
||||
self, keycloak_user_id: str
|
||||
) -> LinearUser | None:
|
||||
"""Get Linear user by Keycloak user ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearUser)
|
||||
.filter(
|
||||
LinearUser.keycloak_user_id == keycloak_user_id,
|
||||
LinearUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_user_by_keycloak_id_and_workspace(
|
||||
self, keycloak_user_id: str, linear_workspace_id: int
|
||||
) -> Optional[LinearUser]:
|
||||
"""Get Linear user by Keycloak user ID and workspace ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearUser)
|
||||
.filter(
|
||||
LinearUser.keycloak_user_id == keycloak_user_id,
|
||||
LinearUser.linear_workspace_id == linear_workspace_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_active_user(
|
||||
self, linear_user_id: str, linear_workspace_id: int
|
||||
) -> Optional[LinearUser]:
|
||||
"""Get Linear user by Keycloak user ID and workspace ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearUser)
|
||||
.filter(
|
||||
LinearUser.linear_user_id == linear_user_id,
|
||||
LinearUser.linear_workspace_id == linear_workspace_id,
|
||||
LinearUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_user_integration_status(
|
||||
self, keycloak_user_id: str, status: str
|
||||
) -> LinearUser:
|
||||
"""Update Linear user integration status."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
LinearUser.keycloak_user_id == keycloak_user_id
|
||||
)
|
||||
with session_maker() as session:
|
||||
linear_user = (
|
||||
session.query(LinearUser)
|
||||
.filter(LinearUser.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
)
|
||||
linear_user = result.scalar_one_or_none()
|
||||
|
||||
if not linear_user:
|
||||
raise ValueError(
|
||||
@@ -188,36 +188,38 @@ class LinearIntegrationStore:
|
||||
)
|
||||
|
||||
linear_user.status = status
|
||||
await session.commit()
|
||||
await session.refresh(linear_user)
|
||||
session.commit()
|
||||
session.refresh(linear_user)
|
||||
|
||||
logger.info(f'[Linear] Updated user {keycloak_user_id} status to {status}')
|
||||
return linear_user
|
||||
|
||||
async def deactivate_workspace(self, workspace_id: int):
|
||||
"""Deactivate the workspace and all user links for a given workspace."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
with session_maker() as session:
|
||||
users = (
|
||||
session.query(LinearUser)
|
||||
.filter(
|
||||
LinearUser.linear_workspace_id == workspace_id,
|
||||
LinearUser.status == 'active',
|
||||
)
|
||||
.all()
|
||||
)
|
||||
users = result.scalars().all()
|
||||
|
||||
for user in users:
|
||||
user.status = 'inactive'
|
||||
session.add(user)
|
||||
|
||||
result = await session.execute(
|
||||
select(LinearWorkspace).where(LinearWorkspace.id == workspace_id)
|
||||
workspace = (
|
||||
session.query(LinearWorkspace)
|
||||
.filter(LinearWorkspace.id == workspace_id)
|
||||
.first()
|
||||
)
|
||||
workspace = result.scalar_one_or_none()
|
||||
if workspace:
|
||||
workspace.status = 'inactive'
|
||||
session.add(workspace)
|
||||
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
logger.info(f'[Jira] Deactivated all user links for workspace {workspace_id}')
|
||||
|
||||
@@ -225,22 +227,23 @@ class LinearIntegrationStore:
|
||||
self, linear_conversation: LinearConversation
|
||||
) -> None:
|
||||
"""Create a new Linear conversation record."""
|
||||
async with a_session_maker() as session:
|
||||
with session_maker() as session:
|
||||
session.add(linear_conversation)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
async def get_user_conversations_by_issue_id(
|
||||
self, issue_id: str, linear_user_id: int
|
||||
) -> LinearConversation | None:
|
||||
"""Get a Linear conversation by issue ID and linear user ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearConversation).where(
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearConversation)
|
||||
.filter(
|
||||
LinearConversation.issue_id == issue_id,
|
||||
LinearConversation.linear_user_id == linear_user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> LinearIntegrationStore:
|
||||
|
||||
@@ -2,8 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
@@ -13,17 +13,17 @@ from openhands.core.logger import openhands_logger as logger
|
||||
@dataclass
|
||||
class OfflineTokenStore:
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
|
||||
async def store_token(self, offline_token: str) -> None:
|
||||
"""Store an offline token in the database."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StoredOfflineToken).where(
|
||||
StoredOfflineToken.user_id == self.user_id
|
||||
)
|
||||
with self.session_maker() as session:
|
||||
token_record = (
|
||||
session.query(StoredOfflineToken)
|
||||
.filter(StoredOfflineToken.user_id == self.user_id)
|
||||
.first()
|
||||
)
|
||||
token_record = result.scalar_one_or_none()
|
||||
|
||||
if token_record:
|
||||
token_record.offline_token = offline_token
|
||||
@@ -32,17 +32,16 @@ class OfflineTokenStore:
|
||||
user_id=self.user_id, offline_token=offline_token
|
||||
)
|
||||
session.add(token_record)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
async def load_token(self) -> str | None:
|
||||
"""Load an offline token from the database."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StoredOfflineToken).where(
|
||||
StoredOfflineToken.user_id == self.user_id
|
||||
)
|
||||
with self.session_maker() as session:
|
||||
token_record = (
|
||||
session.query(StoredOfflineToken)
|
||||
.filter(StoredOfflineToken.user_id == self.user_id)
|
||||
.first()
|
||||
)
|
||||
token_record = result.scalar_one_or_none()
|
||||
|
||||
if not token_record:
|
||||
return None
|
||||
@@ -57,4 +56,4 @@ class OfflineTokenStore:
|
||||
logger.debug(f'offline_token_store.get_instance::{user_id}')
|
||||
if user_id:
|
||||
user_id = str(user_id)
|
||||
return OfflineTokenStore(user_id, config)
|
||||
return OfflineTokenStore(user_id, session_maker, config)
|
||||
|
||||
@@ -3,7 +3,6 @@ Service class for managing organization operations.
|
||||
Separates business logic from route handlers.
|
||||
"""
|
||||
|
||||
from typing import NoReturn
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
@@ -326,7 +325,7 @@ class OrgService:
|
||||
user_id: str,
|
||||
original_error: Exception,
|
||||
error_message: str,
|
||||
) -> NoReturn:
|
||||
) -> None:
|
||||
"""
|
||||
Handle failure by cleaning up LiteLLM resources and raising appropriate error.
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from integrations.github.github_types import (
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import a_session_maker
|
||||
from storage.proactive_convos import ProactiveConversation
|
||||
|
||||
@@ -19,6 +20,8 @@ from openhands.integrations.service_types import ProviderType
|
||||
|
||||
@dataclass
|
||||
class ProactiveConversationStore:
|
||||
a_session_maker: sessionmaker = a_session_maker
|
||||
|
||||
def get_repo_id(self, provider: ProviderType, repo_id):
|
||||
return f'{provider.value}##{repo_id}'
|
||||
|
||||
@@ -48,7 +51,7 @@ class ProactiveConversationStore:
|
||||
|
||||
final_workflow_group = None
|
||||
|
||||
async with a_session_maker() as session:
|
||||
async with self.a_session_maker() as session:
|
||||
# Start an explicit transaction with row-level locking
|
||||
async with session.begin():
|
||||
# Get the existing proactive conversation entry with FOR UPDATE lock
|
||||
@@ -139,7 +142,7 @@ class ProactiveConversationStore:
|
||||
# Calculate the cutoff time (current time - older_than_minutes)
|
||||
cutoff_time = datetime.now(UTC) - timedelta(minutes=older_than_minutes)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
async with self.a_session_maker() as session:
|
||||
async with session.begin():
|
||||
# Delete records older than the cutoff time
|
||||
delete_stmt = delete(ProactiveConversation).where(
|
||||
@@ -155,9 +158,9 @@ class ProactiveConversationStore:
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> ProactiveConversationStore:
|
||||
"""Get an instance of the ProactiveConversationStore.
|
||||
"""Get an instance of the GitlabWebhookStore.
|
||||
|
||||
Returns:
|
||||
An instance of ProactiveConversationStore
|
||||
An instance of GitlabWebhookStore
|
||||
"""
|
||||
return ProactiveConversationStore()
|
||||
return ProactiveConversationStore(a_session_maker)
|
||||
|
||||
@@ -2,8 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_repository import StoredRepository
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
@@ -11,11 +11,12 @@ from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
|
||||
@dataclass
|
||||
class RepositoryStore:
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
|
||||
async def store_projects(self, repositories: list[StoredRepository]) -> None:
|
||||
def store_projects(self, repositories: list[StoredRepository]) -> None:
|
||||
"""
|
||||
Store repositories in database (async version)
|
||||
Store repositories in database
|
||||
|
||||
1. Make sure to store repositories if its ID doesn't exist
|
||||
2. If repository ID already exists, make sure to only update the repo is_public and repo_name fields
|
||||
@@ -25,15 +26,17 @@ class RepositoryStore:
|
||||
if not repositories:
|
||||
return
|
||||
|
||||
async with a_session_maker() as session:
|
||||
with self.session_maker() as session:
|
||||
# Extract all repo_ids to check
|
||||
repo_ids = [r.repo_id for r in repositories]
|
||||
|
||||
# Get all existing repositories in a single query
|
||||
result = await session.execute(
|
||||
select(StoredRepository).filter(StoredRepository.repo_id.in_(repo_ids))
|
||||
)
|
||||
existing_repos = {r.repo_id: r for r in result.scalars().all()}
|
||||
existing_repos = {
|
||||
r.repo_id: r
|
||||
for r in session.query(StoredRepository).filter(
|
||||
StoredRepository.repo_id.in_(repo_ids)
|
||||
)
|
||||
}
|
||||
|
||||
# Process all repositories
|
||||
for repo in repositories:
|
||||
@@ -47,9 +50,9 @@ class RepositoryStore:
|
||||
session.add(repo)
|
||||
|
||||
# Commit all changes
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, config: OpenHandsConfig) -> RepositoryStore:
|
||||
"""Get an instance of the UserRepositoryStore."""
|
||||
return RepositoryStore(config)
|
||||
return RepositoryStore(session_maker, config)
|
||||
|
||||
@@ -234,8 +234,6 @@ class SaasConversationStore(ConversationStore):
|
||||
cls, config: OpenHandsConfig, user_id: str | None
|
||||
) -> ConversationStore:
|
||||
# user_id should not be None in SaaS, should we raise?
|
||||
# Use async version since callers now use asyncio.run_coroutine_threadsafe()
|
||||
# to dispatch to the main event loop where asyncpg connections work properly.
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
return SaasConversationStore(str(user_id), org_id, session_maker)
|
||||
|
||||
@@ -28,7 +28,7 @@ class SaasConversationValidator(ConversationValidator):
|
||||
|
||||
# Validate the API key and get the user_id
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
user_id = await api_key_store.validate_api_key(api_key)
|
||||
user_id = api_key_store.validate_api_key(api_key)
|
||||
|
||||
if not user_id:
|
||||
logger.warning('Invalid API key')
|
||||
|
||||
@@ -5,8 +5,8 @@ from base64 import b64decode, b64encode
|
||||
from dataclasses import dataclass
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import delete, select
|
||||
from storage.database import a_session_maker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.user_store import UserStore
|
||||
|
||||
@@ -19,6 +19,7 @@ from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
@dataclass
|
||||
class SaasSecretsStore(SecretsStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
|
||||
async def load(self) -> Secrets | None:
|
||||
@@ -27,15 +28,14 @@ class SaasSecretsStore(SecretsStore):
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
|
||||
async with a_session_maker() as session:
|
||||
with self.session_maker() as session:
|
||||
# Fetch all secrets for the given user ID
|
||||
query = select(StoredCustomSecrets).filter(
|
||||
query = session.query(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
)
|
||||
if org_id is not None:
|
||||
query = query.filter(StoredCustomSecrets.org_id == org_id)
|
||||
result = await session.execute(query)
|
||||
settings = result.scalars().all()
|
||||
settings = query.all()
|
||||
|
||||
if not settings:
|
||||
return Secrets()
|
||||
@@ -54,15 +54,12 @@ class SaasSecretsStore(SecretsStore):
|
||||
async def store(self, item: Secrets):
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
org_id = user.current_org_id
|
||||
|
||||
async with a_session_maker() as session:
|
||||
with self.session_maker() as session:
|
||||
# Incoming secrets are always the most updated ones
|
||||
# Delete all existing records and override with incoming ones
|
||||
await session.execute(
|
||||
delete(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
)
|
||||
)
|
||||
session.query(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
).delete()
|
||||
|
||||
# Prepare the new secrets data
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
@@ -92,7 +89,7 @@ class SaasSecretsStore(SecretsStore):
|
||||
)
|
||||
session.add(new_secret)
|
||||
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
def _decrypt_kwargs(self, kwargs: dict):
|
||||
fernet = self._fernet()
|
||||
@@ -136,4 +133,4 @@ class SaasSecretsStore(SecretsStore):
|
||||
if not user_id:
|
||||
raise Exception('SaasSecretsStore cannot be constructed with no user_id')
|
||||
logger.debug(f'saas_secrets_store.get_instance::{user_id}')
|
||||
return SaasSecretsStore(user_id, config)
|
||||
return SaasSecretsStore(user_id, session_maker, config)
|
||||
|
||||
@@ -10,9 +10,8 @@ from cryptography.fernet import Fernet
|
||||
from pydantic import SecretStr
|
||||
from server.constants import LITE_LLM_API_URL
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker
|
||||
from sqlalchemy.orm import joinedload, sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager, get_openhands_cloud_key_alias
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
@@ -24,24 +23,26 @@ from storage.user_store import UserStore
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.llm import is_openhands_model
|
||||
|
||||
|
||||
@dataclass
|
||||
class SaasSettingsStore(SettingsStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
ENCRYPT_VALUES = ['llm_api_key', 'llm_api_key_for_byor', 'search_api_key']
|
||||
|
||||
async def _get_user_settings_by_keycloak_id_async(
|
||||
def _get_user_settings_by_keycloak_id(
|
||||
self, keycloak_user_id: str, session=None
|
||||
) -> UserSettings | None:
|
||||
"""
|
||||
Get UserSettings by keycloak_user_id (async version).
|
||||
Get UserSettings by keycloak_user_id.
|
||||
|
||||
Args:
|
||||
keycloak_user_id: The keycloak user ID to search for
|
||||
session: Optional existing async database session. If not provided, creates a new one.
|
||||
session: Optional existing database session. If not provided, creates a new one.
|
||||
|
||||
Returns:
|
||||
UserSettings object if found, None otherwise
|
||||
@@ -49,26 +50,27 @@ class SaasSettingsStore(SettingsStore):
|
||||
if not keycloak_user_id:
|
||||
return None
|
||||
|
||||
if session:
|
||||
# Use provided session
|
||||
result = await session.execute(
|
||||
select(UserSettings).filter(
|
||||
UserSettings.keycloak_user_id == keycloak_user_id
|
||||
def _get_settings():
|
||||
if session:
|
||||
# Use provided session
|
||||
return (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
else:
|
||||
# Create new session
|
||||
async with a_session_maker() as new_session:
|
||||
result = await new_session.execute(
|
||||
select(UserSettings).filter(
|
||||
UserSettings.keycloak_user_id == keycloak_user_id
|
||||
else:
|
||||
# Create new session
|
||||
with self.session_maker() as new_session:
|
||||
return (
|
||||
new_session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
return _get_settings()
|
||||
|
||||
async def load(self) -> Settings | None:
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
if not user:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
return None
|
||||
@@ -81,7 +83,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = await OrgStore.get_org_by_id_async(org_id)
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
@@ -120,22 +122,21 @@ class SaasSettingsStore(SettingsStore):
|
||||
return settings
|
||||
|
||||
async def store(self, item: Settings):
|
||||
async with a_session_maker() as session:
|
||||
with self.session_maker() as session:
|
||||
if not item:
|
||||
return None
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(self.user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
).first()
|
||||
|
||||
if not user:
|
||||
# Check if we need to migrate from user_settings
|
||||
user_settings = None
|
||||
async with a_session_maker() as new_session:
|
||||
user_settings = await self._get_user_settings_by_keycloak_id_async(
|
||||
self.user_id, new_session
|
||||
with session_maker() as session:
|
||||
user_settings = self._get_user_settings_by_keycloak_id(
|
||||
self.user_id, session
|
||||
)
|
||||
if user_settings:
|
||||
user = await UserStore.migrate_user(self.user_id, user_settings)
|
||||
@@ -153,8 +154,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
|
||||
result = await session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
org: Org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
@@ -173,7 +173,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
if hasattr(model, key):
|
||||
setattr(model, key, value)
|
||||
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(
|
||||
@@ -182,7 +182,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
user_id: str, # type: ignore[override]
|
||||
) -> SaasSettingsStore:
|
||||
logger.debug(f'saas_settings_store.get_instance::{user_id}')
|
||||
return SaasSettingsStore(user_id, config)
|
||||
return SaasSettingsStore(user_id, session_maker, config)
|
||||
|
||||
def _should_encrypt(self, key):
|
||||
return key in self.ENCRYPT_VALUES
|
||||
|
||||
@@ -2,35 +2,38 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.slack_conversation import SlackConversation
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlackConversationStore:
|
||||
session_maker: sessionmaker
|
||||
|
||||
async def get_slack_conversation(
|
||||
self, channel_id: str, parent_id: str
|
||||
) -> SlackConversation | None:
|
||||
"""Get a slack conversation by channel_id and message_ts.
|
||||
Both parameters are required to match for a conversation to be returned.
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(SlackConversation).where(
|
||||
SlackConversation.channel_id == channel_id,
|
||||
SlackConversation.parent_id == parent_id,
|
||||
)
|
||||
with session_maker() as session:
|
||||
conversation = (
|
||||
session.query(SlackConversation)
|
||||
.filter(SlackConversation.channel_id == channel_id)
|
||||
.filter(SlackConversation.parent_id == parent_id)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
return conversation
|
||||
|
||||
async def create_slack_conversation(
|
||||
self, slack_converstion: SlackConversation
|
||||
) -> None:
|
||||
async with a_session_maker() as session:
|
||||
with self.session_maker() as session:
|
||||
session.merge(slack_converstion)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> SlackConversationStore:
|
||||
return SlackConversationStore()
|
||||
return SlackConversationStore(session_maker)
|
||||
|
||||
@@ -32,7 +32,6 @@ class SlackTeamStore:
|
||||
# Store the token
|
||||
session.add(slack_team)
|
||||
session.commit()
|
||||
return slack_team
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
|
||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
@@ -12,11 +12,12 @@ from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
|
||||
@dataclass
|
||||
class UserRepositoryMapStore:
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
|
||||
async def store_user_repo_mappings(self, mappings: list[UserRepositoryMap]) -> None:
|
||||
def store_user_repo_mappings(self, mappings: list[UserRepositoryMap]) -> None:
|
||||
"""
|
||||
Store user-repository mappings in database (async version)
|
||||
Store user-repository mappings in database
|
||||
|
||||
1. Make sure to store mappings if they don't exist
|
||||
2. If a mapping already exists (same user_id and repo_id), update the admin field
|
||||
@@ -29,20 +30,18 @@ class UserRepositoryMapStore:
|
||||
if not mappings:
|
||||
return
|
||||
|
||||
async with a_session_maker() as session:
|
||||
with self.session_maker() as session:
|
||||
# Extract all user_id/repo_id pairs to check
|
||||
mapping_keys = [(m.user_id, m.repo_id) for m in mappings]
|
||||
|
||||
# Get all existing mappings in a single query
|
||||
result = await session.execute(
|
||||
select(UserRepositoryMap).filter(
|
||||
existing_mappings = {
|
||||
(m.user_id, m.repo_id): m
|
||||
for m in session.query(UserRepositoryMap).filter(
|
||||
sqlalchemy.tuple_(
|
||||
UserRepositoryMap.user_id, UserRepositoryMap.repo_id
|
||||
).in_(mapping_keys)
|
||||
)
|
||||
)
|
||||
existing_mappings = {
|
||||
(m.user_id, m.repo_id): m for m in result.scalars().all()
|
||||
}
|
||||
|
||||
# Process all mappings
|
||||
@@ -57,9 +56,9 @@ class UserRepositoryMapStore:
|
||||
session.add(mapping)
|
||||
|
||||
# Commit all changes
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, config: OpenHandsConfig) -> UserRepositoryMapStore:
|
||||
"""Get an instance of the UserRepositoryMapStore."""
|
||||
return UserRepositoryMapStore(config)
|
||||
return UserRepositoryMapStore(session_maker, config)
|
||||
|
||||
@@ -227,7 +227,7 @@ class UserStore:
|
||||
'user_store:migrate_user:calling_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await migrate_customer(user_id, org)
|
||||
await migrate_customer(session, user_id, org)
|
||||
logger.debug(
|
||||
'user_store:migrate_user:done_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
|
||||
@@ -8,16 +8,10 @@ from server.verified_models.verified_model_service import (
|
||||
StoredVerifiedModel, # noqa: F401
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
|
||||
# Anything not loaded here may not have a table created for it.
|
||||
from storage.api_key import ApiKey # noqa: F401
|
||||
from storage.base import Base
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.device_code import DeviceCode # noqa: F401
|
||||
@@ -36,18 +30,9 @@ from storage.stripe_customer import StripeCustomer
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def db_path(tmp_path):
|
||||
"""Create a unique temp file path for each test."""
|
||||
return str(tmp_path / 'test.db')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine(db_path):
|
||||
"""Create a sync engine with tables using file-based DB."""
|
||||
engine = create_engine(
|
||||
f'sqlite:///{db_path}', connect_args={'check_same_thread': False}
|
||||
)
|
||||
def engine():
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
@@ -57,36 +42,6 @@ def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_engine(db_path):
|
||||
"""Create an async engine using the SAME file-based database."""
|
||||
async_engine = create_async_engine(
|
||||
f'sqlite+aiosqlite:///{db_path}',
|
||||
connect_args={'check_same_thread': False},
|
||||
)
|
||||
|
||||
async def create_tables():
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Run the async function synchronously
|
||||
import asyncio
|
||||
|
||||
asyncio.run(create_tables())
|
||||
return async_engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker bound to the async engine."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
bind=async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
return async_session_maker
|
||||
|
||||
|
||||
def add_minimal_fixtures(session_maker):
|
||||
with session_maker() as session:
|
||||
session.add(
|
||||
|
||||
@@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from integrations.jira.jira_manager import JiraManager
|
||||
from integrations.jira.jira_payload import JiraEventType, JiraWebhookPayload
|
||||
from integrations.models import Message, SourceType
|
||||
|
||||
from openhands.server.types import (
|
||||
LLMAuthenticationError,
|
||||
@@ -273,8 +274,9 @@ class TestSendMessage:
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
message = Message(source=SourceType.JIRA, message='Test message')
|
||||
result = await jira_manager.send_message(
|
||||
'Test message',
|
||||
message,
|
||||
'PROJ-123',
|
||||
'cloud-123',
|
||||
'service@test.com',
|
||||
|
||||
@@ -1,268 +0,0 @@
|
||||
"""
|
||||
Tests for JiraPayloadParser.
|
||||
|
||||
These tests verify the parsing behavior of Jira webhook payloads,
|
||||
including the handling of optional fields like user_email which
|
||||
may not be present in webhook payloads from Jira.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from integrations.jira.jira_payload import (
|
||||
JiraEventType,
|
||||
JiraPayloadError,
|
||||
JiraPayloadParser,
|
||||
JiraPayloadSkipped,
|
||||
JiraPayloadSuccess,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parser():
|
||||
"""Create a JiraPayloadParser with standard OpenHands labels."""
|
||||
return JiraPayloadParser(oh_label='openhands', inline_oh_label='@openhands')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_label_payload():
|
||||
"""Create a valid jira:issue_updated payload with OpenHands label."""
|
||||
return {
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'issue': {
|
||||
'id': '12345',
|
||||
'key': 'TEST-123',
|
||||
'self': 'https://test.atlassian.net/rest/api/2/issue/12345',
|
||||
},
|
||||
'user': {
|
||||
'displayName': 'Test User',
|
||||
'accountId': 'account-123',
|
||||
'emailAddress': 'test@example.com',
|
||||
},
|
||||
'changelog': {
|
||||
'items': [
|
||||
{
|
||||
'field': 'labels',
|
||||
'toString': 'openhands',
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_comment_payload():
|
||||
"""Create a valid comment_created payload with OpenHands mention."""
|
||||
return {
|
||||
'webhookEvent': 'comment_created',
|
||||
'issue': {
|
||||
'id': '12345',
|
||||
'key': 'TEST-123',
|
||||
'self': 'https://test.atlassian.net/rest/api/2/issue/12345',
|
||||
},
|
||||
'comment': {
|
||||
'body': '@openhands please fix this bug',
|
||||
'author': {
|
||||
'displayName': 'Test User',
|
||||
'accountId': 'account-123',
|
||||
'emailAddress': 'test@example.com',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestUserEmailOptional:
|
||||
"""Tests verifying user_email is optional in webhook payloads.
|
||||
|
||||
Jira webhooks may not include emailAddress in the user data.
|
||||
The parser should accept payloads without this field.
|
||||
"""
|
||||
|
||||
def test_label_event_succeeds_without_email_address(
|
||||
self, parser, valid_label_payload
|
||||
):
|
||||
"""Verify label event parsing succeeds when emailAddress is missing."""
|
||||
# Arrange - remove emailAddress from user data
|
||||
del valid_label_payload['user']['emailAddress']
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.user_email == ''
|
||||
assert result.payload.display_name == 'Test User'
|
||||
assert result.payload.account_id == 'account-123'
|
||||
|
||||
def test_comment_event_succeeds_without_email_address(
|
||||
self, parser, valid_comment_payload
|
||||
):
|
||||
"""Verify comment event parsing succeeds when emailAddress is missing."""
|
||||
# Arrange - remove emailAddress from author data
|
||||
del valid_comment_payload['comment']['author']['emailAddress']
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_comment_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.user_email == ''
|
||||
assert result.payload.display_name == 'Test User'
|
||||
assert result.payload.account_id == 'account-123'
|
||||
|
||||
def test_user_email_preserved_when_present(self, parser, valid_label_payload):
|
||||
"""Verify user_email is captured when emailAddress is present."""
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.user_email == 'test@example.com'
|
||||
|
||||
|
||||
class TestRequiredFieldValidation:
|
||||
"""Tests verifying required fields are still validated."""
|
||||
|
||||
def test_missing_issue_id_returns_error(self, parser, valid_label_payload):
|
||||
"""Verify parsing fails when issue.id is missing."""
|
||||
# Arrange
|
||||
del valid_label_payload['issue']['id']
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadError)
|
||||
assert 'issue.id' in result.error
|
||||
|
||||
def test_missing_issue_key_returns_error(self, parser, valid_label_payload):
|
||||
"""Verify parsing fails when issue.key is missing."""
|
||||
# Arrange
|
||||
del valid_label_payload['issue']['key']
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadError)
|
||||
assert 'issue.key' in result.error
|
||||
|
||||
def test_missing_display_name_returns_error(self, parser, valid_label_payload):
|
||||
"""Verify parsing fails when user.displayName is missing."""
|
||||
# Arrange
|
||||
del valid_label_payload['user']['displayName']
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadError)
|
||||
assert 'displayName' in result.error
|
||||
|
||||
def test_missing_account_id_returns_error(self, parser, valid_label_payload):
|
||||
"""Verify parsing fails when user.accountId is missing."""
|
||||
# Arrange
|
||||
del valid_label_payload['user']['accountId']
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadError)
|
||||
assert 'accountId' in result.error
|
||||
|
||||
def test_missing_issue_self_url_returns_error(self, parser, valid_label_payload):
|
||||
"""Verify parsing fails when issue.self URL is missing."""
|
||||
# Arrange
|
||||
del valid_label_payload['issue']['self']
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadError)
|
||||
assert 'workspace_name' in result.error or 'base_api_url' in result.error
|
||||
|
||||
|
||||
class TestEventTypeDetection:
|
||||
"""Tests for webhook event type detection."""
|
||||
|
||||
def test_issue_updated_with_label_returns_labeled_ticket(
|
||||
self, parser, valid_label_payload
|
||||
):
|
||||
"""Verify jira:issue_updated with label is detected as LABELED_TICKET."""
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.event_type == JiraEventType.LABELED_TICKET
|
||||
|
||||
def test_comment_created_with_mention_returns_comment_mention(
|
||||
self, parser, valid_comment_payload
|
||||
):
|
||||
"""Verify comment_created with mention is detected as COMMENT_MENTION."""
|
||||
# Act
|
||||
result = parser.parse(valid_comment_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.event_type == JiraEventType.COMMENT_MENTION
|
||||
|
||||
def test_unhandled_event_type_returns_skipped(self, parser):
|
||||
"""Verify unknown event types are skipped."""
|
||||
# Arrange
|
||||
payload = {'webhookEvent': 'jira:issue_deleted'}
|
||||
|
||||
# Act
|
||||
result = parser.parse(payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
assert 'Unhandled' in result.skip_reason
|
||||
|
||||
|
||||
class TestLabelFiltering:
|
||||
"""Tests for OpenHands label filtering."""
|
||||
|
||||
def test_label_event_without_openhands_label_skipped(
|
||||
self, parser, valid_label_payload
|
||||
):
|
||||
"""Verify label events without OpenHands label are skipped."""
|
||||
# Arrange - change label to something else
|
||||
valid_label_payload['changelog']['items'][0]['toString'] = 'other-label'
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
assert 'openhands' in result.skip_reason
|
||||
|
||||
|
||||
class TestCommentFiltering:
|
||||
"""Tests for OpenHands comment mention filtering."""
|
||||
|
||||
def test_comment_without_mention_skipped(self, parser, valid_comment_payload):
|
||||
"""Verify comments without OpenHands mention are skipped."""
|
||||
# Arrange - remove mention from comment body
|
||||
valid_comment_payload['comment']['body'] = 'Please fix this bug'
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_comment_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
assert '@openhands' in result.skip_reason
|
||||
|
||||
|
||||
class TestWorkspaceExtraction:
|
||||
"""Tests for workspace name extraction from issue URL."""
|
||||
|
||||
def test_workspace_name_extracted_from_self_url(self, parser, valid_label_payload):
|
||||
"""Verify workspace name is extracted from issue self URL."""
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.workspace_name == 'test.atlassian.net'
|
||||
assert result.payload.base_api_url == 'https://test.atlassian.net'
|
||||
@@ -738,7 +738,7 @@ class TestStartJob:
|
||||
# Should send error message about re-login
|
||||
jira_dc_manager.send_message.assert_called_once()
|
||||
call_args = jira_dc_manager.send_message.call_args[0]
|
||||
assert 'Please re-login' in call_args[0]
|
||||
assert 'Please re-login' in call_args[0].message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_llm_authentication_error(
|
||||
@@ -763,7 +763,7 @@ class TestStartJob:
|
||||
# Should send error message about LLM API key
|
||||
jira_dc_manager.send_message.assert_called_once()
|
||||
call_args = jira_dc_manager.send_message.call_args[0]
|
||||
assert 'valid LLM API key' in call_args[0]
|
||||
assert 'valid LLM API key' in call_args[0].message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_session_expired_error(
|
||||
@@ -788,8 +788,8 @@ class TestStartJob:
|
||||
# Should send error message about session expired
|
||||
jira_dc_manager.send_message.assert_called_once()
|
||||
call_args = jira_dc_manager.send_message.call_args[0]
|
||||
assert 'session has expired' in call_args[0]
|
||||
assert 'login again' in call_args[0]
|
||||
assert 'session has expired' in call_args[0].message
|
||||
assert 'login again' in call_args[0].message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_unexpected_error(
|
||||
@@ -814,7 +814,7 @@ class TestStartJob:
|
||||
# Should send generic error message
|
||||
jira_dc_manager.send_message.assert_called_once()
|
||||
call_args = jira_dc_manager.send_message.call_args[0]
|
||||
assert 'unexpected error' in call_args[0]
|
||||
assert 'unexpected error' in call_args[0].message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_send_message_fails(
|
||||
@@ -943,8 +943,9 @@ class TestSendMessage:
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
message = Message(source=SourceType.JIRA_DC, message='Test message')
|
||||
result = await jira_dc_manager.send_message(
|
||||
'Test message', 'PROJ-123', 'https://jira.company.com', 'bearer_token'
|
||||
message, 'PROJ-123', 'https://jira.company.com', 'bearer_token'
|
||||
)
|
||||
|
||||
assert result == {'id': 'comment_id'}
|
||||
@@ -1013,7 +1014,7 @@ class TestSendRepoSelectionComment:
|
||||
|
||||
jira_dc_manager.send_message.assert_called_once()
|
||||
call_args = jira_dc_manager.send_message.call_args[0]
|
||||
assert 'which repository to work with' in call_args[0]
|
||||
assert 'which repository to work with' in call_args[0].message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_repo_selection_comment_send_fails(
|
||||
|
||||
@@ -18,11 +18,9 @@ from openhands.core.schema.agent import AgentState
|
||||
class TestJiraDcNewConversationView:
|
||||
"""Tests for JiraDcNewConversationView"""
|
||||
|
||||
async def test_get_instructions(self, new_conversation_view, mock_jinja_env):
|
||||
def test_get_instructions(self, new_conversation_view, mock_jinja_env):
|
||||
"""Test _get_instructions method"""
|
||||
instructions, user_msg = await new_conversation_view._get_instructions(
|
||||
mock_jinja_env
|
||||
)
|
||||
instructions, user_msg = new_conversation_view._get_instructions(mock_jinja_env)
|
||||
|
||||
assert instructions == 'Test Jira DC instructions template'
|
||||
assert 'PROJ-123' in user_msg
|
||||
@@ -85,9 +83,9 @@ class TestJiraDcNewConversationView:
|
||||
class TestJiraDcExistingConversationView:
|
||||
"""Tests for JiraDcExistingConversationView"""
|
||||
|
||||
async def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
|
||||
def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
|
||||
"""Test _get_instructions method"""
|
||||
instructions, user_msg = await existing_conversation_view._get_instructions(
|
||||
instructions, user_msg = existing_conversation_view._get_instructions(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
|
||||
@@ -802,7 +802,7 @@ class TestStartJob:
|
||||
# Should send error message about re-login
|
||||
linear_manager.send_message.assert_called_once()
|
||||
call_args = linear_manager.send_message.call_args[0]
|
||||
assert 'Please re-login' in call_args[0]
|
||||
assert 'Please re-login' in call_args[0].message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_llm_authentication_error(
|
||||
@@ -828,7 +828,7 @@ class TestStartJob:
|
||||
# Should send error message about LLM API key
|
||||
linear_manager.send_message.assert_called_once()
|
||||
call_args = linear_manager.send_message.call_args[0]
|
||||
assert 'valid LLM API key' in call_args[0]
|
||||
assert 'valid LLM API key' in call_args[0].message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_session_expired_error(
|
||||
@@ -854,8 +854,8 @@ class TestStartJob:
|
||||
# Should send error message about session expired
|
||||
linear_manager.send_message.assert_called_once()
|
||||
call_args = linear_manager.send_message.call_args[0]
|
||||
assert 'session has expired' in call_args[0]
|
||||
assert 'login again' in call_args[0]
|
||||
assert 'session has expired' in call_args[0].message
|
||||
assert 'login again' in call_args[0].message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_unexpected_error(
|
||||
@@ -881,7 +881,7 @@ class TestStartJob:
|
||||
# Should send generic error message
|
||||
linear_manager.send_message.assert_called_once()
|
||||
call_args = linear_manager.send_message.call_args[0]
|
||||
assert 'unexpected error' in call_args[0]
|
||||
assert 'unexpected error' in call_args[0].message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_send_message_fails(
|
||||
@@ -1049,9 +1049,8 @@ class TestSendMessage:
|
||||
|
||||
linear_manager._query_api = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await linear_manager.send_message(
|
||||
'Test message', 'issue_id', 'api_key'
|
||||
)
|
||||
message = Message(source=SourceType.LINEAR, message='Test message')
|
||||
result = await linear_manager.send_message(message, 'issue_id', 'api_key')
|
||||
|
||||
assert result == mock_response
|
||||
linear_manager._query_api.assert_called_once()
|
||||
@@ -1115,7 +1114,7 @@ class TestSendRepoSelectionComment:
|
||||
|
||||
linear_manager.send_message.assert_called_once()
|
||||
call_args = linear_manager.send_message.call_args[0]
|
||||
assert 'which repository to work with' in call_args[0]
|
||||
assert 'which repository to work with' in call_args[0].message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_repo_selection_comment_send_fails(
|
||||
|
||||
@@ -18,11 +18,9 @@ from openhands.core.schema.agent import AgentState
|
||||
class TestLinearNewConversationView:
|
||||
"""Tests for LinearNewConversationView"""
|
||||
|
||||
async def test_get_instructions(self, new_conversation_view, mock_jinja_env):
|
||||
def test_get_instructions(self, new_conversation_view, mock_jinja_env):
|
||||
"""Test _get_instructions method"""
|
||||
instructions, user_msg = await new_conversation_view._get_instructions(
|
||||
mock_jinja_env
|
||||
)
|
||||
instructions, user_msg = new_conversation_view._get_instructions(mock_jinja_env)
|
||||
|
||||
assert instructions == 'Test instructions template'
|
||||
assert 'TEST-123' in user_msg
|
||||
@@ -85,9 +83,9 @@ class TestLinearNewConversationView:
|
||||
class TestLinearExistingConversationView:
|
||||
"""Tests for LinearExistingConversationView"""
|
||||
|
||||
async def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
|
||||
def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
|
||||
"""Test _get_instructions method"""
|
||||
instructions, user_msg = await existing_conversation_view._get_instructions(
|
||||
instructions, user_msg = existing_conversation_view._get_instructions(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
|
||||
@@ -263,9 +263,7 @@ class TestPausedSandboxResumption:
|
||||
@patch('openhands.app_server.config.get_httpx_client')
|
||||
@patch('openhands.app_server.event_callback.util.ensure_running_sandbox')
|
||||
@patch('openhands.app_server.event_callback.util.get_agent_server_url_from_sandbox')
|
||||
@patch.object(
|
||||
SlackUpdateExistingConversationView, '_get_instructions', new_callable=AsyncMock
|
||||
)
|
||||
@patch.object(SlackUpdateExistingConversationView, '_get_instructions')
|
||||
async def test_paused_sandbox_resumption(
|
||||
self,
|
||||
mock_get_instructions,
|
||||
|
||||
@@ -34,6 +34,7 @@ async def test_send_comment_to_jira_success(mock_jira_manager, processor):
|
||||
)
|
||||
mock_jira_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_jira_manager.send_message = AsyncMock()
|
||||
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
# Action
|
||||
await processor._send_comment_to_jira('This is a summary.')
|
||||
@@ -129,6 +130,7 @@ async def test_call_sends_summary_to_jira(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_jira_manager.send_message = AsyncMock()
|
||||
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
with patch(
|
||||
'server.conversation_callback_processor.jira_callback_processor.asyncio.create_task'
|
||||
@@ -198,6 +200,7 @@ async def test_send_comment_to_jira_api_error(mock_jira_manager, processor):
|
||||
)
|
||||
mock_jira_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_jira_manager.send_message = AsyncMock(side_effect=Exception('API Error'))
|
||||
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
# Action - should not raise exception, but handle it gracefully
|
||||
await processor._send_comment_to_jira('This is a summary.')
|
||||
@@ -325,15 +328,18 @@ async def test_send_comment_to_jira_message_construction(mock_jira_manager, proc
|
||||
)
|
||||
mock_jira_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_jira_manager.send_message = AsyncMock()
|
||||
mock_outgoing_message = MagicMock()
|
||||
mock_jira_manager.create_outgoing_message.return_value = mock_outgoing_message
|
||||
|
||||
test_message = 'This is a test summary message.'
|
||||
|
||||
# Action
|
||||
await processor._send_comment_to_jira(test_message)
|
||||
|
||||
# Assert - send_message now receives the string directly
|
||||
# Assert
|
||||
mock_jira_manager.create_outgoing_message.assert_called_once_with(msg=test_message)
|
||||
mock_jira_manager.send_message.assert_called_once_with(
|
||||
test_message,
|
||||
mock_outgoing_message,
|
||||
issue_key='TEST-123',
|
||||
jira_cloud_id='cloud123',
|
||||
svc_acc_email='service@test.com',
|
||||
@@ -380,6 +386,7 @@ async def test_call_creates_background_task_for_sending(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_jira_manager.send_message = AsyncMock()
|
||||
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
with patch(
|
||||
'server.conversation_callback_processor.jira_callback_processor.asyncio.create_task'
|
||||
|
||||
@@ -32,6 +32,7 @@ async def test_send_comment_to_jira_dc_success(mock_jira_dc_manager, processor):
|
||||
)
|
||||
mock_jira_dc_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_jira_dc_manager.send_message = AsyncMock()
|
||||
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
# Action
|
||||
await processor._send_comment_to_jira_dc('This is a summary.')
|
||||
@@ -124,6 +125,7 @@ async def test_call_sends_summary_to_jira_dc(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_jira_dc_manager.send_message = AsyncMock()
|
||||
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
with patch(
|
||||
'server.conversation_callback_processor.jira_dc_callback_processor.asyncio.create_task'
|
||||
@@ -198,6 +200,7 @@ async def test_send_comment_to_jira_dc_api_error(mock_jira_dc_manager, processor
|
||||
)
|
||||
mock_jira_dc_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_jira_dc_manager.send_message = AsyncMock(side_effect=Exception('API Error'))
|
||||
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
# Action - should not raise exception, but handle it gracefully
|
||||
await processor._send_comment_to_jira_dc('This is a summary.')
|
||||
@@ -325,15 +328,20 @@ async def test_send_comment_to_jira_dc_message_construction(
|
||||
)
|
||||
mock_jira_dc_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_jira_dc_manager.send_message = AsyncMock()
|
||||
mock_outgoing_message = MagicMock()
|
||||
mock_jira_dc_manager.create_outgoing_message.return_value = mock_outgoing_message
|
||||
|
||||
test_message = 'This is a test summary message.'
|
||||
|
||||
# Action
|
||||
await processor._send_comment_to_jira_dc(test_message)
|
||||
|
||||
# Assert - send_message now receives the string directly
|
||||
# Assert
|
||||
mock_jira_dc_manager.create_outgoing_message.assert_called_once_with(
|
||||
msg=test_message
|
||||
)
|
||||
mock_jira_dc_manager.send_message.assert_called_once_with(
|
||||
test_message,
|
||||
mock_outgoing_message,
|
||||
issue_key='TEST-123',
|
||||
base_api_url='https://test-jira-dc.company.com',
|
||||
svc_acc_api_key='decrypted_key',
|
||||
@@ -376,6 +384,7 @@ async def test_call_creates_background_task_for_sending(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_jira_dc_manager.send_message = AsyncMock()
|
||||
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
with patch(
|
||||
'server.conversation_callback_processor.jira_dc_callback_processor.asyncio.create_task'
|
||||
|
||||
@@ -32,6 +32,7 @@ async def test_send_comment_to_linear_success(mock_linear_manager, processor):
|
||||
)
|
||||
mock_linear_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_linear_manager.send_message = AsyncMock()
|
||||
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
# Action
|
||||
await processor._send_comment_to_linear('This is a summary.')
|
||||
@@ -124,6 +125,7 @@ async def test_call_sends_summary_to_linear(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_linear_manager.send_message = AsyncMock()
|
||||
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
with patch(
|
||||
'server.conversation_callback_processor.linear_callback_processor.asyncio.create_task'
|
||||
@@ -198,6 +200,7 @@ async def test_send_comment_to_linear_api_error(mock_linear_manager, processor):
|
||||
)
|
||||
mock_linear_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_linear_manager.send_message = AsyncMock(side_effect=Exception('API Error'))
|
||||
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
# Action - should not raise exception, but handle it gracefully
|
||||
await processor._send_comment_to_linear('This is a summary.')
|
||||
@@ -325,15 +328,20 @@ async def test_send_comment_to_linear_message_construction(
|
||||
)
|
||||
mock_linear_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_linear_manager.send_message = AsyncMock()
|
||||
mock_outgoing_message = MagicMock()
|
||||
mock_linear_manager.create_outgoing_message.return_value = mock_outgoing_message
|
||||
|
||||
test_message = 'This is a test summary message.'
|
||||
|
||||
# Action
|
||||
await processor._send_comment_to_linear(test_message)
|
||||
|
||||
# Assert - send_message now receives the string directly
|
||||
# Assert
|
||||
mock_linear_manager.create_outgoing_message.assert_called_once_with(
|
||||
msg=test_message
|
||||
)
|
||||
mock_linear_manager.send_message.assert_called_once_with(
|
||||
test_message,
|
||||
mock_outgoing_message,
|
||||
'TEST-123', # issue_id
|
||||
'decrypted_key', # api_key
|
||||
)
|
||||
@@ -375,6 +383,7 @@ async def test_call_creates_background_task_for_sending(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_linear_manager.send_message = AsyncMock()
|
||||
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
with patch(
|
||||
'server.conversation_callback_processor.linear_callback_processor.asyncio.create_task'
|
||||
|
||||
@@ -16,15 +16,8 @@ from storage.device_code import DeviceCode
|
||||
|
||||
@pytest.fixture
|
||||
def mock_device_code_store():
|
||||
"""Mock device code store with async methods."""
|
||||
mock = MagicMock()
|
||||
mock.create_device_code = AsyncMock()
|
||||
mock.get_by_device_code = AsyncMock()
|
||||
mock.get_by_user_code = AsyncMock()
|
||||
mock.authorize_device_code = AsyncMock()
|
||||
mock.deny_device_code = AsyncMock()
|
||||
mock.update_poll_time = AsyncMock()
|
||||
return mock
|
||||
"""Mock device code store."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -61,7 +54,7 @@ class TestDeviceAuthorization:
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
current_interval=5, # Default interval
|
||||
)
|
||||
mock_store.create_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.create_device_code.return_value = mock_device
|
||||
|
||||
result = await device_authorization(mock_request)
|
||||
|
||||
@@ -83,7 +76,7 @@ class TestDeviceAuthorization:
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
current_interval=15, # Increased interval from previous rate limiting
|
||||
)
|
||||
mock_store.create_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.create_device_code.return_value = mock_device
|
||||
|
||||
result = await device_authorization(mock_request)
|
||||
|
||||
@@ -120,10 +113,10 @@ class TestDeviceToken:
|
||||
mock_device.status = status
|
||||
# Mock rate limiting - return False (not too fast) and default interval
|
||||
mock_device.check_rate_limit.return_value = (False, 5)
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
else:
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=None)
|
||||
mock_store.get_by_device_code.return_value = None
|
||||
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
@@ -149,14 +142,12 @@ class TestDeviceToken:
|
||||
)
|
||||
# Mock rate limiting - return False (not too fast) and default interval
|
||||
mock_device.check_rate_limit.return_value = (False, 5)
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
# Mock API key retrieval - use AsyncMock for async method
|
||||
# Mock API key retrieval
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.retrieve_api_key_by_name = AsyncMock(
|
||||
return_value='test-api-key'
|
||||
)
|
||||
mock_api_key_store.retrieve_api_key_by_name.return_value = 'test-api-key'
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
result = await device_token(device_code=device_code)
|
||||
@@ -185,7 +176,7 @@ class TestDeviceVerificationAuthenticated:
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test verification with invalid device code."""
|
||||
mock_store.get_by_user_code = AsyncMock(return_value=None)
|
||||
mock_store.get_by_user_code.return_value = None
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
await device_verification_authenticated(
|
||||
@@ -198,7 +189,7 @@ class TestDeviceVerificationAuthenticated:
|
||||
"""Test verification with already processed device code."""
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = False
|
||||
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
await device_verification_authenticated(
|
||||
@@ -212,8 +203,8 @@ class TestDeviceVerificationAuthenticated:
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.authorize_device_code = AsyncMock(return_value=True)
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True
|
||||
|
||||
# Mock API key store with async create_api_key
|
||||
mock_api_key_store = MagicMock()
|
||||
@@ -257,17 +248,15 @@ class TestDeviceVerificationAuthenticated:
|
||||
mock_device2.is_pending.return_value = True
|
||||
|
||||
# Configure mock store to return appropriate device for each user_code
|
||||
async def get_by_user_code_side_effect(user_code):
|
||||
def get_by_user_code_side_effect(user_code):
|
||||
if user_code == device1_code:
|
||||
return mock_device1
|
||||
elif user_code == device2_code:
|
||||
return mock_device2
|
||||
return None
|
||||
|
||||
mock_store.get_by_user_code = AsyncMock(
|
||||
side_effect=get_by_user_code_side_effect
|
||||
)
|
||||
mock_store.authorize_device_code = AsyncMock(return_value=True)
|
||||
mock_store.get_by_user_code.side_effect = get_by_user_code_side_effect
|
||||
mock_store.authorize_device_code.return_value = True
|
||||
|
||||
# Authenticate first device
|
||||
result1 = await device_verification_authenticated(
|
||||
@@ -316,8 +305,8 @@ class TestDeviceTokenRateLimiting:
|
||||
last_poll_time=None, # First poll
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
@@ -347,8 +336,8 @@ class TestDeviceTokenRateLimiting:
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
@@ -378,8 +367,8 @@ class TestDeviceTokenRateLimiting:
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
@@ -410,8 +399,8 @@ class TestDeviceTokenRateLimiting:
|
||||
last_poll_time=last_poll,
|
||||
current_interval=15, # Already increased from previous slow_down
|
||||
)
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
@@ -441,8 +430,8 @@ class TestDeviceTokenRateLimiting:
|
||||
last_poll_time=last_poll,
|
||||
current_interval=58, # Near maximum of 60
|
||||
)
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
@@ -468,8 +457,8 @@ class TestDeviceTokenRateLimiting:
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
@@ -498,10 +487,8 @@ class TestDeviceVerificationTransactionIntegrity:
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.authorize_device_code = AsyncMock(
|
||||
return_value=False
|
||||
) # Authorization fails
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = False # Authorization fails
|
||||
|
||||
# Mock API key store with async create_api_key
|
||||
mock_api_key_store = MagicMock()
|
||||
@@ -532,11 +519,9 @@ class TestDeviceVerificationTransactionIntegrity:
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.authorize_device_code = AsyncMock(
|
||||
return_value=True
|
||||
) # Authorization succeeds
|
||||
mock_store.deny_device_code = AsyncMock(return_value=True) # Cleanup succeeds
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
mock_store.deny_device_code.return_value = True # Cleanup succeeds
|
||||
|
||||
# Mock API key store to fail on creation (async)
|
||||
mock_api_key_store = MagicMock()
|
||||
@@ -574,12 +559,10 @@ class TestDeviceVerificationTransactionIntegrity:
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.authorize_device_code = AsyncMock(
|
||||
return_value=True
|
||||
) # Authorization succeeds
|
||||
mock_store.deny_device_code = AsyncMock(
|
||||
side_effect=Exception('Cleanup failed')
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
mock_store.deny_device_code.side_effect = Exception(
|
||||
'Cleanup failed'
|
||||
) # Cleanup fails
|
||||
|
||||
# Mock API key store to fail on creation (async)
|
||||
@@ -612,11 +595,8 @@ class TestDeviceVerificationTransactionIntegrity:
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.authorize_device_code = AsyncMock(
|
||||
return_value=True
|
||||
) # Authorization succeeds
|
||||
mock_store.deny_device_code = AsyncMock()
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
|
||||
# Mock API key store with async create_api_key
|
||||
mock_api_key_store = MagicMock()
|
||||
|
||||
@@ -11,37 +11,43 @@ import httpx
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from fastapi.testclient import TestClient
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
CannotModifySelfError,
|
||||
InsufficientPermissionError,
|
||||
InvalidRoleError,
|
||||
LastOwnerError,
|
||||
LiteLLMIntegrationError,
|
||||
MeResponse,
|
||||
OrgAppSettingsResponse,
|
||||
OrgAppSettingsUpdate,
|
||||
OrgAuthorizationError,
|
||||
OrgDatabaseError,
|
||||
OrgMemberNotFoundError,
|
||||
OrgMemberPage,
|
||||
OrgMemberResponse,
|
||||
OrgMemberUpdate,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
OrphanedUserError,
|
||||
RoleNotFoundError,
|
||||
)
|
||||
from server.routes.orgs import (
|
||||
get_me,
|
||||
get_org_members,
|
||||
org_router,
|
||||
remove_org_member,
|
||||
update_org_member,
|
||||
)
|
||||
from storage.org import Org
|
||||
|
||||
from openhands.server.user_auth import get_user_id
|
||||
# Mock database before imports
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
CannotModifySelfError,
|
||||
InsufficientPermissionError,
|
||||
InvalidRoleError,
|
||||
LastOwnerError,
|
||||
LiteLLMIntegrationError,
|
||||
MeResponse,
|
||||
OrgAppSettingsResponse,
|
||||
OrgAppSettingsUpdate,
|
||||
OrgAuthorizationError,
|
||||
OrgDatabaseError,
|
||||
OrgMemberNotFoundError,
|
||||
OrgMemberPage,
|
||||
OrgMemberResponse,
|
||||
OrgMemberUpdate,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
OrphanedUserError,
|
||||
RoleNotFoundError,
|
||||
)
|
||||
from server.routes.orgs import (
|
||||
get_me,
|
||||
get_org_members,
|
||||
org_router,
|
||||
remove_org_member,
|
||||
update_org_member,
|
||||
)
|
||||
from storage.org import Org
|
||||
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
# Test user ID constant (must be a valid UUID string)
|
||||
TEST_USER_ID = str(uuid.uuid4())
|
||||
|
||||
@@ -399,135 +399,3 @@ class TestUpdateActiveWorkingSeconds:
|
||||
assert conversation_work.seconds == 23.0
|
||||
assert conversation_work.conversation_id == conversation_id
|
||||
assert conversation_work.user_id == user_id
|
||||
|
||||
|
||||
class TestInvokeConversationCallbacks:
|
||||
"""Tests for invoke_conversation_callbacks function.
|
||||
|
||||
This function uses async database sessions (a_session_maker) to query
|
||||
and invoke callbacks for a conversation.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_observation(self):
|
||||
"""Create a mock AgentStateChangedObservation."""
|
||||
|
||||
observation = Mock(spec=AgentStateChangedObservation)
|
||||
observation.agent_state = AgentState.FINISHED
|
||||
return observation
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_async_session(self):
|
||||
"""Factory to create properly mocked async session context manager."""
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
def _create(callbacks_list):
|
||||
mock_session = Mock()
|
||||
mock_result = Mock()
|
||||
mock_result.scalars.return_value.all.return_value = callbacks_list
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock(return_value=None)
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_context_manager():
|
||||
yield mock_session
|
||||
|
||||
return mock_context_manager, mock_session
|
||||
|
||||
return _create
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_callbacks_with_active_callbacks(
|
||||
self, mock_observation, create_mock_async_session
|
||||
):
|
||||
"""Test that active callbacks are invoked successfully."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
# Arrange
|
||||
conversation_id = 'test_conversation_callbacks'
|
||||
mock_processor = AsyncMock(return_value=None)
|
||||
|
||||
# Create a mock callback
|
||||
mock_callback = Mock()
|
||||
mock_callback.id = 1
|
||||
mock_callback.processor_type = 'test_processor'
|
||||
mock_callback.get_processor.return_value = mock_processor
|
||||
|
||||
mock_context_manager, mock_session = create_mock_async_session([mock_callback])
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.utils.conversation_callback_utils.a_session_maker',
|
||||
mock_context_manager,
|
||||
):
|
||||
from server.utils.conversation_callback_utils import (
|
||||
invoke_conversation_callbacks,
|
||||
)
|
||||
|
||||
await invoke_conversation_callbacks(conversation_id, mock_observation)
|
||||
|
||||
# Assert
|
||||
mock_callback.get_processor.assert_called_once()
|
||||
mock_processor.assert_called_once_with(mock_callback, mock_observation)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_callbacks_with_no_active_callbacks(
|
||||
self, mock_observation, create_mock_async_session
|
||||
):
|
||||
"""Test behavior when no active callbacks exist."""
|
||||
# Arrange
|
||||
conversation_id = 'test_no_callbacks'
|
||||
|
||||
mock_context_manager, mock_session = create_mock_async_session([])
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.utils.conversation_callback_utils.a_session_maker',
|
||||
mock_context_manager,
|
||||
):
|
||||
from server.utils.conversation_callback_utils import (
|
||||
invoke_conversation_callbacks,
|
||||
)
|
||||
|
||||
await invoke_conversation_callbacks(conversation_id, mock_observation)
|
||||
|
||||
# Assert - should complete without errors
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_callbacks_handles_processor_exception(
|
||||
self, mock_observation, create_mock_async_session
|
||||
):
|
||||
"""Test that processor exceptions are caught and callback status is updated."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
# Arrange
|
||||
conversation_id = 'test_callback_error'
|
||||
mock_processor = AsyncMock(side_effect=Exception('Processor error'))
|
||||
|
||||
mock_callback = Mock()
|
||||
mock_callback.id = 1
|
||||
mock_callback.processor_type = 'failing_processor'
|
||||
mock_callback.get_processor.return_value = mock_processor
|
||||
mock_callback.status = 'active'
|
||||
|
||||
mock_context_manager, mock_session = create_mock_async_session([mock_callback])
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.utils.conversation_callback_utils.a_session_maker',
|
||||
mock_context_manager,
|
||||
), patch('server.utils.conversation_callback_utils.logger') as mock_logger:
|
||||
from server.utils.conversation_callback_utils import (
|
||||
invoke_conversation_callbacks,
|
||||
)
|
||||
from storage.conversation_callback import CallbackStatus
|
||||
|
||||
await invoke_conversation_callbacks(conversation_id, mock_observation)
|
||||
|
||||
# Assert - callback status should be set to ERROR
|
||||
assert mock_callback.status == CallbackStatus.ERROR
|
||||
mock_logger.error.assert_called_once()
|
||||
error_call = mock_logger.error.call_args
|
||||
assert error_call[0][0] == 'callback_invocation_failed'
|
||||
|
||||
@@ -1,127 +1,127 @@
|
||||
"""Unit tests for AuthTokenStore using SQLite in-memory database."""
|
||||
"""Unit tests for AuthTokenStore."""
|
||||
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from server.auth.auth_error import TokenRefreshError
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from storage.auth_token_store import (
|
||||
ACCESS_TOKEN_EXPIRY_BUFFER,
|
||||
LOCK_TIMEOUT_SECONDS,
|
||||
AuthTokenStore,
|
||||
)
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.base import Base
|
||||
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
)
|
||||
return engine
|
||||
def create_mock_session():
|
||||
"""Create a mock async session with properly configured context managers."""
|
||||
session = AsyncMock()
|
||||
|
||||
# Create async context manager for begin()
|
||||
@asynccontextmanager
|
||||
async def begin_context():
|
||||
yield
|
||||
|
||||
session.begin = begin_context
|
||||
return session
|
||||
|
||||
|
||||
def create_mock_session_maker(mock_session):
|
||||
"""Create a mock async session maker."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def session_context():
|
||||
yield mock_session
|
||||
|
||||
# Return a callable that returns the context manager
|
||||
return lambda: session_context()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker bound to the async engine."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
bind=async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
def mock_session():
|
||||
"""Create mock async session."""
|
||||
return create_mock_session()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(mock_session):
|
||||
"""Create mock async session maker."""
|
||||
return create_mock_session_maker(mock_session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_token_store(mock_session_maker):
|
||||
"""Create AuthTokenStore instance with mocked session maker."""
|
||||
return AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
# Create all tables
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
return async_session_maker
|
||||
|
||||
|
||||
class TestIsTokenExpired:
|
||||
"""Tests for _is_token_expired method."""
|
||||
|
||||
def test_both_tokens_valid(self):
|
||||
def test_both_tokens_valid(self, auth_token_store):
|
||||
"""Test when both tokens are valid (not expired)."""
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
current_time = int(time.time())
|
||||
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
refresh_expires = current_time + 1000
|
||||
|
||||
access_expired, refresh_expired = store._is_token_expired(
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is False
|
||||
assert refresh_expired is False
|
||||
|
||||
def test_access_token_expired(self):
|
||||
def test_access_token_expired(self, auth_token_store):
|
||||
"""Test when access token is expired but within buffer."""
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
current_time = int(time.time())
|
||||
# Access token expires within buffer period
|
||||
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER - 100
|
||||
refresh_expires = current_time + 10000
|
||||
|
||||
access_expired, refresh_expired = store._is_token_expired(
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is True
|
||||
assert refresh_expired is False
|
||||
|
||||
def test_refresh_token_expired(self):
|
||||
def test_refresh_token_expired(self, auth_token_store):
|
||||
"""Test when refresh token is expired."""
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
current_time = int(time.time())
|
||||
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
refresh_expires = current_time - 100 # Already expired
|
||||
|
||||
access_expired, refresh_expired = store._is_token_expired(
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is False
|
||||
assert refresh_expired is True
|
||||
|
||||
def test_both_tokens_expired(self):
|
||||
def test_both_tokens_expired(self, auth_token_store):
|
||||
"""Test when both tokens are expired."""
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
current_time = int(time.time())
|
||||
access_expires = current_time - 100
|
||||
refresh_expires = current_time - 100
|
||||
|
||||
access_expired, refresh_expired = store._is_token_expired(
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is True
|
||||
assert refresh_expired is True
|
||||
|
||||
def test_zero_expiration_treated_as_never_expires(self):
|
||||
def test_zero_expiration_treated_as_never_expires(self, auth_token_store):
|
||||
"""Test that 0 expiration time is treated as never expires."""
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
access_expired, refresh_expired = store._is_token_expired(0, 0)
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(0, 0)
|
||||
|
||||
assert access_expired is False
|
||||
assert refresh_expired is False
|
||||
@@ -131,188 +131,427 @@ class TestLoadTokensFastPath:
|
||||
"""Tests for load_tokens fast path (no lock needed)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_token_not_found(self, async_session_maker):
|
||||
async def test_fast_path_token_not_found(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test fast path returns None when no token record exists."""
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = await store.load_tokens()
|
||||
result = await auth_token_store.load_tokens()
|
||||
|
||||
assert result is None
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_valid_token_no_refresh_needed(self, async_session_maker):
|
||||
async def test_fast_path_valid_token_no_refresh_needed(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test fast path returns tokens when they are still valid."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'valid-access-token'
|
||||
mock_token.refresh_token = 'valid-refresh-token'
|
||||
mock_token.access_token_expires_at = (
|
||||
current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
)
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
# First, store a valid token in the database
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
await store.store_tokens(
|
||||
access_token='valid-access-token',
|
||||
refresh_token='valid-refresh-token',
|
||||
access_token_expires_at=current_time
|
||||
+ ACCESS_TOKEN_EXPIRY_BUFFER
|
||||
+ 1000,
|
||||
refresh_token_expires_at=current_time + 10000,
|
||||
)
|
||||
result = await auth_token_store.load_tokens()
|
||||
|
||||
# Now load tokens - should return valid tokens without refresh
|
||||
result = await store.load_tokens()
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'valid-access-token'
|
||||
assert result['refresh_token'] == 'valid-refresh-token'
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'valid-access-token'
|
||||
assert result['refresh_token'] == 'valid-refresh-token'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_no_refresh_callback_provided(self, async_session_maker):
|
||||
async def test_fast_path_no_refresh_callback_provided(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test fast path returns existing tokens when no refresh callback is provided."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'expired-access-token'
|
||||
mock_token.refresh_token = 'valid-refresh-token'
|
||||
# Expired access token
|
||||
mock_token.access_token_expires_at = current_time - 100
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
# Store expired access token
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
await store.store_tokens(
|
||||
access_token='expired-access-token',
|
||||
refresh_token='valid-refresh-token',
|
||||
access_token_expires_at=current_time - 100, # Expired
|
||||
refresh_token_expires_at=current_time + 10000,
|
||||
)
|
||||
result = await auth_token_store.load_tokens(check_expiration_and_refresh=None)
|
||||
|
||||
# Load without refresh callback - should still return tokens
|
||||
result = await store.load_tokens(check_expiration_and_refresh=None)
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'expired-access-token'
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'expired-access-token'
|
||||
|
||||
|
||||
class TestLoadTokensSlowPath:
|
||||
"""Tests for load_tokens slow path (lock required for refresh).
|
||||
"""Tests for load_tokens slow path (lock required for refresh)."""
|
||||
|
||||
Note: These tests require PostgreSQL's lock_timeout feature which is not
|
||||
available in SQLite. The slow path tests are skipped when using SQLite.
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason='SQLite does not support PostgreSQL lock_timeout syntax')
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_path_successful_refresh(self, async_session_maker):
|
||||
async def test_slow_path_successful_refresh(self):
|
||||
"""Test slow path successfully refreshes expired tokens."""
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason='SQLite does not support PostgreSQL lock_timeout syntax')
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_callback_returns_none(self, async_session_maker):
|
||||
"""Test behavior when refresh callback returns None (no refresh performed)."""
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_path_double_check_avoids_refresh(self, async_session_maker):
|
||||
"""Test double-check pattern avoids unnecessary refresh."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
# First call (fast path) - returns expired token
|
||||
# Second call (slow path) - returns same token for update
|
||||
expired_token = MagicMock()
|
||||
expired_token.id = 1
|
||||
expired_token.access_token = 'expired-access-token'
|
||||
expired_token.refresh_token = 'valid-refresh-token'
|
||||
expired_token.access_token_expires_at = current_time - 100 # Expired
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(
|
||||
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
|
||||
) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-access-token',
|
||||
'refresh_token': 'new-refresh-token',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'new-access-token'
|
||||
assert result['refresh_token'] == 'new-refresh-token'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_path_double_check_avoids_refresh(self):
|
||||
"""Test double-check locking: token was refreshed by another request."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
# Simulate scenario:
|
||||
# 1. Fast path sees expired token
|
||||
# 2. While waiting for lock, another request refreshes
|
||||
# 3. Slow path sees fresh token, skips refresh
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def create_token():
|
||||
call_count[0] += 1
|
||||
token = MagicMock()
|
||||
token.id = 1
|
||||
token.access_token = 'fresh-access-token'
|
||||
token.refresh_token = 'fresh-refresh-token'
|
||||
if call_count[0] == 1:
|
||||
# First call (fast path) - expired
|
||||
token.access_token_expires_at = current_time - 100
|
||||
else:
|
||||
# Second call (slow path) - already refreshed
|
||||
token.access_token_expires_at = (
|
||||
current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
)
|
||||
token.refresh_token_expires_at = current_time + 86400
|
||||
return token
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.side_effect = (
|
||||
lambda: create_token()
|
||||
)
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
refresh_called = [False]
|
||||
|
||||
async def mock_refresh(
|
||||
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
|
||||
) -> Dict[str, str | int]:
|
||||
refresh_called[0] = True
|
||||
return {
|
||||
'access_token': 'should-not-be-used',
|
||||
'refresh_token': 'should-not-be-used',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
# The refresh callback should not be called because double-check
|
||||
# found the token was already refreshed
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'fresh-access-token'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_path_token_not_found_after_lock(self):
|
||||
"""Test slow path returns None if token record disappears after lock."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
# First call (fast path) - token exists but expired
|
||||
# Second call (slow path with lock) - token no longer exists
|
||||
call_count = [0]
|
||||
|
||||
def get_token():
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
token = MagicMock()
|
||||
token.access_token_expires_at = current_time - 100 # Expired
|
||||
token.refresh_token_expires_at = current_time + 10000
|
||||
return token
|
||||
return None
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.side_effect = get_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(*args) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-token',
|
||||
'refresh_token': 'new-refresh',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestLoadTokensLockTimeout:
|
||||
"""Tests for lock timeout handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_timeout_raises_token_refresh_error(self):
|
||||
"""Test that lock timeout raises TokenRefreshError."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
# First call (fast path) - returns expired token
|
||||
expired_token = MagicMock()
|
||||
expired_token.access_token_expires_at = current_time - 100
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
|
||||
# First execute for fast path succeeds
|
||||
# Second execute (for slow path) raises OperationalError
|
||||
call_count = [0]
|
||||
|
||||
async def execute_side_effect(*args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] <= 1:
|
||||
return mock_result
|
||||
# Simulate lock timeout
|
||||
raise OperationalError(
|
||||
'canceling statement due to lock timeout', None, None
|
||||
)
|
||||
|
||||
# Store a token that will be valid when second check happens
|
||||
await store.store_tokens(
|
||||
access_token='original-access-token',
|
||||
refresh_token='valid-refresh-token',
|
||||
access_token_expires_at=current_time
|
||||
+ ACCESS_TOKEN_EXPIRY_BUFFER
|
||||
+ 1000,
|
||||
refresh_token_expires_at=current_time + 10000,
|
||||
)
|
||||
mock_session.execute = execute_side_effect
|
||||
|
||||
# Load with refresh callback - should NOT refresh since token is valid
|
||||
result = await store.load_tokens()
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'original-access-token'
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(*args) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-token',
|
||||
'refresh_token': 'new-refresh',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
with pytest.raises(TokenRefreshError) as exc_info:
|
||||
await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
assert 'lock timeout' in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_timeout_preserves_original_exception(self):
|
||||
"""Test that TokenRefreshError preserves the original OperationalError."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
expired_token = MagicMock()
|
||||
expired_token.access_token_expires_at = current_time - 100
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
|
||||
original_error = OperationalError(
|
||||
'canceling statement due to lock timeout', None, None
|
||||
)
|
||||
|
||||
call_count = [0]
|
||||
|
||||
async def execute_side_effect(*args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] <= 1:
|
||||
return mock_result
|
||||
raise original_error
|
||||
|
||||
mock_session.execute = execute_side_effect
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(*args) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-token',
|
||||
'refresh_token': 'new-refresh',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
with pytest.raises(TokenRefreshError) as exc_info:
|
||||
await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
# Verify the original exception is chained
|
||||
assert exc_info.value.__cause__ is original_error
|
||||
|
||||
|
||||
class TestLoadTokensRefreshCallbackBehavior:
|
||||
"""Tests for refresh callback return values."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_callback_returns_none(self):
|
||||
"""Test behavior when refresh callback returns None (no refresh performed)."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
expired_token = MagicMock()
|
||||
expired_token.id = 1
|
||||
expired_token.access_token = 'old-access-token'
|
||||
expired_token.refresh_token = 'old-refresh-token'
|
||||
expired_token.access_token_expires_at = current_time - 100 # Expired
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh_returns_none(
|
||||
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
|
||||
) -> Dict[str, str | int] | None:
|
||||
return None
|
||||
|
||||
result = await auth_store.load_tokens(
|
||||
check_expiration_and_refresh=mock_refresh_returns_none
|
||||
)
|
||||
|
||||
# Should return the old tokens when refresh returns None
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'old-access-token'
|
||||
assert result['refresh_token'] == 'old-refresh-token'
|
||||
|
||||
|
||||
class TestStoreTokens:
|
||||
"""Tests for store_tokens method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_tokens_creates_new_record(self, async_session_maker):
|
||||
async def test_store_tokens_creates_new_record(self):
|
||||
"""Test storing tokens when no existing record."""
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
mock_session = create_mock_session()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
await store.store_tokens(
|
||||
access_token='new-access-token',
|
||||
refresh_token='new-refresh-token',
|
||||
access_token_expires_at=1234567890,
|
||||
refresh_token_expires_at=1234657890,
|
||||
)
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
# Verify the token was stored
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(AuthTokens).where(
|
||||
AuthTokens.keycloak_user_id == 'test-user-123',
|
||||
AuthTokens.identity_provider == ProviderType.GITHUB.value,
|
||||
)
|
||||
)
|
||||
token_record = result.scalars().first()
|
||||
assert token_record is not None
|
||||
assert token_record.access_token == 'new-access-token'
|
||||
assert token_record.refresh_token == 'new-refresh-token'
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
await auth_store.store_tokens(
|
||||
access_token='new-access-token',
|
||||
refresh_token='new-refresh-token',
|
||||
access_token_expires_at=1234567890,
|
||||
refresh_token_expires_at=1234657890,
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_tokens_updates_existing_record(self, async_session_maker):
|
||||
async def test_store_tokens_updates_existing_record(self):
|
||||
"""Test storing tokens updates existing record."""
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
mock_session = create_mock_session()
|
||||
existing_token = MagicMock()
|
||||
existing_token.access_token = 'old-access'
|
||||
|
||||
# First, create a token record
|
||||
await store.store_tokens(
|
||||
access_token='old-access-token',
|
||||
refresh_token='old-refresh-token',
|
||||
access_token_expires_at=1234567890,
|
||||
refresh_token_expires_at=1234657890,
|
||||
)
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = existing_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
# Now update it
|
||||
await store.store_tokens(
|
||||
access_token='new-access-token',
|
||||
refresh_token='new-refresh-token',
|
||||
access_token_expires_at=1234567891,
|
||||
refresh_token_expires_at=1234657891,
|
||||
)
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
# Verify the token was updated
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(AuthTokens).where(
|
||||
AuthTokens.keycloak_user_id == 'test-user-123',
|
||||
AuthTokens.identity_provider == ProviderType.GITHUB.value,
|
||||
)
|
||||
)
|
||||
token_record = result.scalars().first()
|
||||
assert token_record is not None
|
||||
assert token_record.access_token == 'new-access-token'
|
||||
assert token_record.refresh_token == 'new-refresh-token'
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
await auth_store.store_tokens(
|
||||
access_token='new-access-token',
|
||||
refresh_token='new-refresh-token',
|
||||
access_token_expires_at=1234567890,
|
||||
refresh_token_expires_at=1234657890,
|
||||
)
|
||||
|
||||
assert existing_token.access_token == 'new-access-token'
|
||||
assert existing_token.refresh_token == 'new-refresh-token'
|
||||
|
||||
|
||||
class TestIsAccessTokenValid:
|
||||
@@ -320,93 +559,80 @@ class TestIsAccessTokenValid:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_access_token_valid_returns_false_when_no_tokens(
|
||||
self, async_session_maker
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test returns False when no tokens found."""
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = await store.is_access_token_valid()
|
||||
result = await auth_token_store.is_access_token_valid()
|
||||
|
||||
assert result is False
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_access_token_valid_returns_true_for_valid_token(
|
||||
self, async_session_maker
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test returns True when token is valid."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'valid-access'
|
||||
mock_token.refresh_token = 'valid-refresh'
|
||||
mock_token.access_token_expires_at = current_time + 1000
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
await store.store_tokens(
|
||||
access_token='valid-access',
|
||||
refresh_token='valid-refresh',
|
||||
access_token_expires_at=current_time + 1000,
|
||||
refresh_token_expires_at=current_time + 10000,
|
||||
)
|
||||
result = await auth_token_store.is_access_token_valid()
|
||||
|
||||
result = await store.is_access_token_valid()
|
||||
|
||||
assert result is True
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_access_token_valid_returns_false_for_expired_token(
|
||||
self, async_session_maker
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test returns False when token is expired."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'expired-access'
|
||||
mock_token.refresh_token = 'valid-refresh'
|
||||
mock_token.access_token_expires_at = current_time - 100 # Expired
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
await store.store_tokens(
|
||||
access_token='expired-access',
|
||||
refresh_token='valid-refresh',
|
||||
access_token_expires_at=current_time - 100, # Expired
|
||||
refresh_token_expires_at=current_time + 10000,
|
||||
)
|
||||
result = await auth_token_store.is_access_token_valid()
|
||||
|
||||
result = await store.is_access_token_valid()
|
||||
|
||||
assert result is False
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetInstance:
|
||||
"""Tests for get_instance class method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instance_creates_auth_token_store(self, async_session_maker):
|
||||
async def test_get_instance_creates_auth_token_store(self):
|
||||
"""Test get_instance creates an AuthTokenStore with correct params."""
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
with patch('storage.auth_token_store.a_session_maker') as mock_a_session_maker:
|
||||
store = await AuthTokenStore.get_instance(
|
||||
keycloak_user_id='user-123', idp=ProviderType.GITHUB
|
||||
)
|
||||
|
||||
assert store.keycloak_user_id == 'user-123'
|
||||
assert store.idp == ProviderType.GITHUB
|
||||
assert store.a_session_maker is mock_a_session_maker
|
||||
|
||||
|
||||
class TestIdentityProviderValue:
|
||||
"""Tests for identity_provider_value property."""
|
||||
|
||||
def test_identity_provider_value_returns_idp_value(self):
|
||||
def test_identity_provider_value_returns_idp_value(self, auth_token_store):
|
||||
"""Test that identity_provider_value returns the enum value."""
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
assert store.identity_provider_value == ProviderType.GITHUB.value
|
||||
assert auth_token_store.identity_provider_value == ProviderType.GITHUB.value
|
||||
|
||||
def test_identity_provider_value_for_different_providers(self):
|
||||
"""Test identity_provider_value for different providers."""
|
||||
@@ -418,6 +644,7 @@ class TestIdentityProviderValue:
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=provider,
|
||||
a_session_maker=MagicMock(),
|
||||
)
|
||||
assert store.identity_provider_value == provider.value
|
||||
|
||||
|
||||
@@ -1,17 +1,33 @@
|
||||
"""Unit tests for DeviceCodeStore."""
|
||||
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from storage.device_code import DeviceCode
|
||||
from storage.device_code_store import DeviceCodeStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def device_code_store():
|
||||
def mock_session():
|
||||
"""Mock database session."""
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(mock_session):
|
||||
"""Mock session maker."""
|
||||
session_maker = MagicMock()
|
||||
session_maker.return_value.__enter__.return_value = mock_session
|
||||
session_maker.return_value.__exit__.return_value = None
|
||||
return session_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def device_code_store(mock_session_maker):
|
||||
"""Create DeviceCodeStore instance."""
|
||||
return DeviceCodeStore()
|
||||
return DeviceCodeStore(mock_session_maker)
|
||||
|
||||
|
||||
class TestDeviceCodeStore:
|
||||
@@ -33,257 +49,145 @@ class TestDeviceCodeStore:
|
||||
assert len(code) == 128
|
||||
assert code.isalnum()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_device_code_success(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
def test_create_device_code_success(self, device_code_store, mock_session):
|
||||
"""Test successful device code creation."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
result = await device_code_store.create_device_code(expires_in=600)
|
||||
# Mock successful creation (no IntegrityError)
|
||||
mock_device_code = MagicMock(spec=DeviceCode)
|
||||
mock_device_code.device_code = 'test-device-code-123'
|
||||
mock_device_code.user_code = 'TESTCODE'
|
||||
|
||||
# Mock the session to return our mock device code after refresh
|
||||
def mock_refresh(obj):
|
||||
obj.device_code = mock_device_code.device_code
|
||||
obj.user_code = mock_device_code.user_code
|
||||
|
||||
mock_session.refresh.side_effect = mock_refresh
|
||||
|
||||
result = device_code_store.create_device_code(expires_in=600)
|
||||
|
||||
assert isinstance(result, DeviceCode)
|
||||
assert len(result.device_code) == 128
|
||||
assert len(result.user_code) == 8
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
mock_session.expunge.assert_called_once()
|
||||
|
||||
# Verify the DeviceCode was created in the database
|
||||
async with async_session_maker() as session:
|
||||
result_db = await session.execute(
|
||||
select(DeviceCode).filter(DeviceCode.device_code == result.device_code)
|
||||
)
|
||||
device_code = result_db.scalars().first()
|
||||
assert device_code is not None
|
||||
assert device_code.user_code == result.user_code
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_device_code_with_retries(
|
||||
self, device_code_store, async_session_maker
|
||||
def test_create_device_code_with_retries(
|
||||
self, device_code_store, mock_session_maker
|
||||
):
|
||||
"""Test device code creation with constraint violation retries."""
|
||||
# First create a device code to cause a collision
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
first_code = await device_code_store.create_device_code(expires_in=600)
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
|
||||
# Patch generate methods to return the same codes on first attempt,
|
||||
# then different codes on second attempt
|
||||
call_count = {'user': 0, 'device': 0}
|
||||
original_generate_user_code = device_code_store.generate_user_code
|
||||
original_generate_device_code = device_code_store.generate_device_code
|
||||
# First attempt fails with IntegrityError, second succeeds
|
||||
mock_session.commit.side_effect = [IntegrityError('', '', ''), None]
|
||||
|
||||
def mock_generate_user_code():
|
||||
call_count['user'] += 1
|
||||
if call_count['user'] == 1:
|
||||
return first_code.user_code # Collision
|
||||
return original_generate_user_code()
|
||||
mock_device_code = MagicMock(spec=DeviceCode)
|
||||
mock_device_code.device_code = 'test-device-code-456'
|
||||
mock_device_code.user_code = 'TESTCD2'
|
||||
|
||||
def mock_generate_device_code():
|
||||
call_count['device'] += 1
|
||||
if call_count['device'] == 1:
|
||||
return first_code.device_code # Collision
|
||||
return original_generate_device_code()
|
||||
def mock_refresh(obj):
|
||||
obj.device_code = mock_device_code.device_code
|
||||
obj.user_code = mock_device_code.user_code
|
||||
|
||||
device_code_store.generate_user_code = mock_generate_user_code
|
||||
device_code_store.generate_device_code = mock_generate_device_code
|
||||
mock_session.refresh.side_effect = mock_refresh
|
||||
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
result = await device_code_store.create_device_code(expires_in=600)
|
||||
store = DeviceCodeStore(mock_session_maker)
|
||||
result = store.create_device_code(expires_in=600)
|
||||
|
||||
assert isinstance(result, DeviceCode)
|
||||
assert result.device_code != first_code.device_code # Should be different
|
||||
assert call_count['user'] == 2 # Two attempts
|
||||
assert mock_session.add.call_count == 2 # Two attempts
|
||||
assert mock_session.commit.call_count == 2 # Two attempts
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_device_code_max_attempts_exceeded(
|
||||
self, device_code_store, async_session_maker
|
||||
def test_create_device_code_max_attempts_exceeded(
|
||||
self, device_code_store, mock_session_maker
|
||||
):
|
||||
"""Test device code creation failure after max attempts."""
|
||||
# First create a device code
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
first_code = await device_code_store.create_device_code(expires_in=600)
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
|
||||
# Always return the same codes to cause repeated collisions
|
||||
device_code_store.generate_user_code = lambda: first_code.user_code
|
||||
device_code_store.generate_device_code = lambda: first_code.device_code
|
||||
# All attempts fail with IntegrityError
|
||||
mock_session.commit.side_effect = IntegrityError('', '', '')
|
||||
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match='Failed to generate unique device codes after 3 attempts',
|
||||
):
|
||||
await device_code_store.create_device_code(
|
||||
expires_in=600, max_attempts=3
|
||||
)
|
||||
store = DeviceCodeStore(mock_session_maker)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_device_code(self, device_code_store, async_session_maker):
|
||||
"""Test getting device code by device code."""
|
||||
# Create a device code first
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
created = await device_code_store.create_device_code(expires_in=600)
|
||||
result = await device_code_store.get_by_device_code(created.device_code)
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match='Failed to generate unique device codes after 3 attempts',
|
||||
):
|
||||
store.create_device_code(expires_in=600, max_attempts=3)
|
||||
|
||||
assert result is not None
|
||||
assert result.device_code == created.device_code
|
||||
assert result.user_code == created.user_code
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_device_code_not_found(
|
||||
self, device_code_store, async_session_maker
|
||||
@pytest.mark.parametrize(
|
||||
'lookup_method,lookup_field',
|
||||
[
|
||||
('get_by_device_code', 'device_code'),
|
||||
('get_by_user_code', 'user_code'),
|
||||
],
|
||||
)
|
||||
def test_lookup_methods(
|
||||
self, device_code_store, mock_session, lookup_method, lookup_field
|
||||
):
|
||||
"""Test getting non-existent device code."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
result = await device_code_store.get_by_device_code('non-existent-code')
|
||||
"""Test device code lookup methods."""
|
||||
test_code = 'test-code-123'
|
||||
mock_device_code = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_device_code
|
||||
)
|
||||
|
||||
assert result is None
|
||||
result = getattr(device_code_store, lookup_method)(test_code)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_user_code(self, device_code_store, async_session_maker):
|
||||
"""Test getting device code by user code."""
|
||||
# Create a device code first
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
created = await device_code_store.create_device_code(expires_in=600)
|
||||
result = await device_code_store.get_by_user_code(created.user_code)
|
||||
assert result == mock_device_code
|
||||
mock_session.query.assert_called_once_with(DeviceCode)
|
||||
mock_session.query.return_value.filter_by.assert_called_once_with(
|
||||
**{lookup_field: test_code}
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.device_code == created.device_code
|
||||
assert result.user_code == created.user_code
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_user_code_not_found(
|
||||
self, device_code_store, async_session_maker
|
||||
@pytest.mark.parametrize(
|
||||
'device_exists,is_pending,expected_result',
|
||||
[
|
||||
(True, True, True), # Success case
|
||||
(False, True, False), # Device not found
|
||||
(True, False, False), # Device not pending
|
||||
],
|
||||
)
|
||||
def test_authorize_device_code(
|
||||
self,
|
||||
device_code_store,
|
||||
mock_session,
|
||||
device_exists,
|
||||
is_pending,
|
||||
expected_result,
|
||||
):
|
||||
"""Test getting non-existent user code."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
result = await device_code_store.get_by_user_code('NOTFOUND')
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_device_code_success(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test successful device code authorization."""
|
||||
"""Test device code authorization."""
|
||||
user_code = 'ABC12345'
|
||||
user_id = 'test-user-123'
|
||||
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
created = await device_code_store.create_device_code(expires_in=600)
|
||||
result = await device_code_store.authorize_device_code(
|
||||
created.user_code, user_id
|
||||
)
|
||||
if device_exists:
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = is_pending
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_device
|
||||
else:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
result = device_code_store.authorize_device_code(user_code, user_id)
|
||||
|
||||
assert result == expected_result
|
||||
if expected_result:
|
||||
mock_device.authorize.assert_called_once_with(user_id)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_deny_device_code(self, device_code_store, mock_session):
|
||||
"""Test device code denial."""
|
||||
user_code = 'ABC12345'
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_device
|
||||
)
|
||||
|
||||
result = device_code_store.deny_device_code(user_code)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify the device code was authorized in the database
|
||||
async with async_session_maker() as session:
|
||||
result_db = await session.execute(
|
||||
select(DeviceCode).filter(DeviceCode.user_code == created.user_code)
|
||||
)
|
||||
device_code = result_db.scalars().first()
|
||||
assert device_code.status == 'authorized'
|
||||
assert device_code.keycloak_user_id == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_device_code_not_found(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test authorizing non-existent device code."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
result = await device_code_store.authorize_device_code(
|
||||
'NOTFOUND', 'user-123'
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_device_code_not_pending(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test authorizing already authorized device code."""
|
||||
user_id = 'test-user-123'
|
||||
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
created = await device_code_store.create_device_code(expires_in=600)
|
||||
# First authorization
|
||||
await device_code_store.authorize_device_code(created.user_code, user_id)
|
||||
# Second authorization should fail
|
||||
result = await device_code_store.authorize_device_code(
|
||||
created.user_code, 'another-user'
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_device_code_success(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test successful device code denial."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
created = await device_code_store.create_device_code(expires_in=600)
|
||||
result = await device_code_store.deny_device_code(created.user_code)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify the device code was denied in the database
|
||||
async with async_session_maker() as session:
|
||||
result_db = await session.execute(
|
||||
select(DeviceCode).filter(DeviceCode.user_code == created.user_code)
|
||||
)
|
||||
device_code = result_db.scalars().first()
|
||||
assert device_code.status == 'denied'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_device_code_not_found(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test denying non-existent device code."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
result = await device_code_store.deny_device_code('NOTFOUND')
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_device_code_not_pending(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test denying already denied device code."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
created = await device_code_store.create_device_code(expires_in=600)
|
||||
# First denial
|
||||
await device_code_store.deny_device_code(created.user_code)
|
||||
# Second denial should fail
|
||||
result = await device_code_store.deny_device_code(created.user_code)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_poll_time_success(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test updating poll time."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
created = await device_code_store.create_device_code(expires_in=600)
|
||||
original_interval = created.current_interval
|
||||
result = await device_code_store.update_poll_time(
|
||||
created.device_code, increase_interval=True
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify the poll time was updated
|
||||
async with async_session_maker() as session:
|
||||
result_db = await session.execute(
|
||||
select(DeviceCode).filter(DeviceCode.device_code == created.device_code)
|
||||
)
|
||||
device_code = result_db.scalars().first()
|
||||
assert device_code.current_interval > original_interval
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_poll_time_not_found(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test updating poll time for non-existent device code."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
result = await device_code_store.update_poll_time(
|
||||
'non-existent-code', increase_interval=False
|
||||
)
|
||||
|
||||
assert result is False
|
||||
mock_device.deny.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@@ -9,35 +9,16 @@ from storage.base import Base
|
||||
from storage.gitlab_webhook import GitlabWebhook
|
||||
from storage.gitlab_webhook_store import GitlabWebhookStore
|
||||
|
||||
# Use module-scoped engine to share database across fixtures
|
||||
_test_engine = None
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for each test case."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
async def async_engine(event_loop):
|
||||
"""Create an async SQLite engine for testing.
|
||||
|
||||
This fixture creates an in-memory SQLite database and ensures
|
||||
all tables are created before tests run.
|
||||
"""
|
||||
global _test_engine
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
_test_engine = engine
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
@@ -48,7 +29,7 @@ async def async_engine(event_loop):
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker for testing."""
|
||||
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
@@ -56,21 +37,8 @@ async def async_session_maker(async_engine):
|
||||
|
||||
@pytest.fixture
|
||||
async def webhook_store(async_session_maker):
|
||||
"""Create a GitlabWebhookStore instance for testing.
|
||||
|
||||
This fixture injects the test's async_session_maker to ensure
|
||||
the store uses the same in-memory database as the test fixtures.
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
|
||||
store = GitlabWebhookStore()
|
||||
|
||||
# Inject the test session maker - this needs to replace the module-level import
|
||||
import storage.gitlab_webhook_store as store_module
|
||||
|
||||
store_module.a_session_maker = async_session_maker
|
||||
|
||||
return store
|
||||
"""Create a GitlabWebhookStore instance for testing."""
|
||||
return GitlabWebhookStore(a_session_maker=async_session_maker)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -134,7 +102,7 @@ class TestGetWebhookByResourceOnly:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_project_webhook_by_resource_only(
|
||||
self, webhook_store, sample_webhooks
|
||||
self, webhook_store, async_session_maker, sample_webhooks
|
||||
):
|
||||
"""Test getting a project webhook by resource ID without user_id filter."""
|
||||
# Arrange
|
||||
|
||||
@@ -1,232 +0,0 @@
|
||||
"""
|
||||
Tests for JiraIntegrationStore async methods.
|
||||
|
||||
The store uses async database sessions (a_session_maker) for all operations,
|
||||
which is critical for avoiding asyncpg event loop issues when called from
|
||||
FastAPI async endpoints.
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from storage.jira_integration_store import JiraIntegrationStore
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store():
|
||||
"""Create a JiraIntegrationStore instance."""
|
||||
return JiraIntegrationStore()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_async_session():
|
||||
"""Factory to create properly mocked async session context manager."""
|
||||
|
||||
def _create(query_result=None, all_results=None):
|
||||
mock_session = Mock()
|
||||
mock_result = Mock()
|
||||
|
||||
if all_results is not None:
|
||||
mock_result.scalars.return_value.all.return_value = all_results
|
||||
else:
|
||||
mock_result.scalars.return_value.first.return_value = query_result
|
||||
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.refresh = AsyncMock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_context_manager():
|
||||
yield mock_session
|
||||
|
||||
return mock_context_manager, mock_session
|
||||
|
||||
return _create
|
||||
|
||||
|
||||
class TestJiraIntegrationStoreAsyncMethods:
|
||||
"""Tests verifying JiraIntegrationStore methods use async sessions correctly."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_workspace_by_id_returns_workspace(
|
||||
self, store, create_mock_async_session
|
||||
):
|
||||
"""Test get_workspace_by_id returns workspace when found."""
|
||||
# Arrange
|
||||
mock_workspace = Mock(spec=JiraWorkspace)
|
||||
mock_workspace.id = 1
|
||||
mock_workspace.name = 'test-workspace'
|
||||
|
||||
mock_context_manager, mock_session = create_mock_async_session(mock_workspace)
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'storage.jira_integration_store.a_session_maker', mock_context_manager
|
||||
):
|
||||
result = await store.get_workspace_by_id(1)
|
||||
|
||||
# Assert
|
||||
assert result == mock_workspace
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_workspace_by_id_returns_none_when_not_found(
|
||||
self, store, create_mock_async_session
|
||||
):
|
||||
"""Test get_workspace_by_id returns None when workspace not found."""
|
||||
# Arrange
|
||||
mock_context_manager, mock_session = create_mock_async_session(None)
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'storage.jira_integration_store.a_session_maker', mock_context_manager
|
||||
):
|
||||
result = await store.get_workspace_by_id(999)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_workspace_by_name_normalizes_to_lowercase(
|
||||
self, store, create_mock_async_session
|
||||
):
|
||||
"""Test get_workspace_by_name converts name to lowercase for query."""
|
||||
# Arrange
|
||||
mock_workspace = Mock(spec=JiraWorkspace)
|
||||
mock_workspace.name = 'test-workspace'
|
||||
|
||||
mock_context_manager, mock_session = create_mock_async_session(mock_workspace)
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'storage.jira_integration_store.a_session_maker', mock_context_manager
|
||||
):
|
||||
result = await store.get_workspace_by_name('TEST-WORKSPACE')
|
||||
|
||||
# Assert
|
||||
assert result == mock_workspace
|
||||
# Verify the query was executed (filter includes lowercase conversion)
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_user_filters_by_status(
|
||||
self, store, create_mock_async_session
|
||||
):
|
||||
"""Test get_active_user only returns users with active status."""
|
||||
# Arrange
|
||||
mock_user = Mock(spec=JiraUser)
|
||||
mock_user.jira_user_id = 'jira-123'
|
||||
mock_user.jira_workspace_id = 1
|
||||
mock_user.status = 'active'
|
||||
|
||||
mock_context_manager, mock_session = create_mock_async_session(mock_user)
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'storage.jira_integration_store.a_session_maker', mock_context_manager
|
||||
):
|
||||
result = await store.get_active_user('jira-123', 1)
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_workspace_adds_and_commits(
|
||||
self, store, create_mock_async_session
|
||||
):
|
||||
"""Test create_workspace properly adds, commits, and refreshes."""
|
||||
# Arrange
|
||||
mock_context_manager, mock_session = create_mock_async_session(None)
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'storage.jira_integration_store.a_session_maker', mock_context_manager
|
||||
):
|
||||
await store.create_workspace(
|
||||
name='TEST-WORKSPACE',
|
||||
jira_cloud_id='cloud-123',
|
||||
admin_user_id='admin-user',
|
||||
encrypted_webhook_secret='encrypted-secret',
|
||||
svc_acc_email='svc@test.com',
|
||||
encrypted_svc_acc_api_key='encrypted-key',
|
||||
status='active',
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
|
||||
# Verify workspace was created with lowercase name
|
||||
added_workspace = mock_session.add.call_args[0][0]
|
||||
assert added_workspace.name == 'test-workspace'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_integration_status_raises_if_not_found(
|
||||
self, store, create_mock_async_session
|
||||
):
|
||||
"""Test update_user_integration_status raises ValueError if user not found."""
|
||||
# Arrange
|
||||
mock_context_manager, mock_session = create_mock_async_session(None)
|
||||
|
||||
# Act & Assert
|
||||
with patch(
|
||||
'storage.jira_integration_store.a_session_maker', mock_context_manager
|
||||
):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await store.update_user_integration_status('unknown-user', 'inactive')
|
||||
|
||||
assert 'Jira user not found' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_workspace_deactivates_all_users(
|
||||
self, store, create_mock_async_session
|
||||
):
|
||||
"""Test deactivate_workspace sets all users and workspace to inactive."""
|
||||
# Arrange
|
||||
mock_user1 = Mock(spec=JiraUser)
|
||||
mock_user1.status = 'active'
|
||||
mock_user2 = Mock(spec=JiraUser)
|
||||
mock_user2.status = 'active'
|
||||
|
||||
mock_workspace = Mock(spec=JiraWorkspace)
|
||||
mock_workspace.status = 'active'
|
||||
|
||||
mock_session = Mock()
|
||||
|
||||
# First execute returns users, second returns workspace
|
||||
call_count = [0]
|
||||
|
||||
def execute_side_effect(*args, **kwargs):
|
||||
result = Mock()
|
||||
if call_count[0] == 0:
|
||||
result.scalars.return_value.all.return_value = [mock_user1, mock_user2]
|
||||
else:
|
||||
result.scalars.return_value.first.return_value = mock_workspace
|
||||
call_count[0] += 1
|
||||
return result
|
||||
|
||||
mock_session.execute = AsyncMock(side_effect=execute_side_effect)
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_context_manager():
|
||||
yield mock_session
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'storage.jira_integration_store.a_session_maker', mock_context_manager
|
||||
):
|
||||
await store.deactivate_workspace(1)
|
||||
|
||||
# Assert
|
||||
assert mock_user1.status == 'inactive'
|
||||
assert mock_user2.status == 'inactive'
|
||||
assert mock_workspace.status == 'inactive'
|
||||
mock_session.commit.assert_called_once()
|
||||
@@ -5,15 +5,21 @@ Tests the async database operations for organization app settings.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from server.routes.org_models import OrgAppSettingsUpdate
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_app_settings_store import OrgAppSettingsStore
|
||||
from storage.user import User
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from server.routes.org_models import OrgAppSettingsUpdate
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_app_settings_store import OrgAppSettingsStore
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -8,13 +8,18 @@ import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from server.routes.org_models import OrgLLMSettingsUpdate
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_llm_settings_store import OrgLLMSettingsStore
|
||||
from storage.user import User
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from server.routes.org_models import OrgLLMSettingsUpdate
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_llm_settings_store import OrgLLMSettingsStore
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -5,15 +5,21 @@ Tests the async database operations for user app settings.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from server.routes.user_app_settings_models import UserAppSettingsUpdate
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
from storage.user_app_settings_store import UserAppSettingsStore
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from server.routes.user_app_settings_models import UserAppSettingsUpdate
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
from storage.user_app_settings_store import UserAppSettingsStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,49 +1,40 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from storage.api_key import ApiKey
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(mock_session):
|
||||
session_maker = MagicMock()
|
||||
session_maker.return_value.__enter__.return_value = mock_session
|
||||
session_maker.return_value.__exit__.return_value = None
|
||||
return session_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Mock user with org_id."""
|
||||
user = MagicMock()
|
||||
user.current_org_id = uuid.uuid4()
|
||||
user.current_org_id = 'test-org-123'
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_store():
|
||||
return ApiKeyStore()
|
||||
def api_key_store(mock_session_maker):
|
||||
return ApiKeyStore(mock_session_maker)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_api():
|
||||
api_key_patch = patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test_key')
|
||||
api_url_patch = patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'
|
||||
)
|
||||
team_id_patch = patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test_team')
|
||||
client_patch = patch('httpx.AsyncClient')
|
||||
|
||||
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.json = MagicMock(return_value={'key': 'test_api_key'})
|
||||
mock_client.return_value.__aenter__.return_value.post.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.get.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value.patch.return_value = (
|
||||
mock_response
|
||||
)
|
||||
yield mock_client
|
||||
def run_sync(func, *args, **kwargs):
|
||||
"""Helper to execute sync functions directly (mocks call_sync_from_async)."""
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def test_generate_api_key(api_key_store):
|
||||
@@ -56,445 +47,294 @@ def test_generate_api_key(api_key_store):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
async def test_create_api_key(
|
||||
mock_get_user, api_key_store, async_session_maker, mock_user
|
||||
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
|
||||
):
|
||||
"""Test creating an API key."""
|
||||
# Setup
|
||||
user_id = str(uuid.uuid4())
|
||||
user_id = 'test-user-123'
|
||||
name = 'Test Key'
|
||||
mock_get_user.return_value = mock_user
|
||||
api_key_store.generate_api_key = MagicMock(return_value='test-api-key')
|
||||
|
||||
# Patch a_session_maker in the api_key_store module to use the test's async session maker
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Execute
|
||||
result = await api_key_store.create_api_key(user_id, name)
|
||||
|
||||
# Verify
|
||||
assert result.startswith('sk-oh-')
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
|
||||
# Verify the ApiKey was created in the database using async session
|
||||
async with async_session_maker() as session:
|
||||
result_db = await session.execute(
|
||||
select(ApiKey).filter(ApiKey.user_id == user_id)
|
||||
)
|
||||
api_key = result_db.scalars().first()
|
||||
assert api_key is not None
|
||||
assert api_key.name == name
|
||||
assert api_key.org_id == mock_user.current_org_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_valid(api_key_store, async_session_maker):
|
||||
"""Test validating a valid API key."""
|
||||
# Setup - create an API key in the database
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
api_key_value = 'test-api-key'
|
||||
|
||||
async with async_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key_value,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='Test Key',
|
||||
expires_at=None,
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.validate_api_key(api_key_value)
|
||||
|
||||
# Verify
|
||||
assert result == user_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_expired(api_key_store, async_session_maker):
|
||||
"""Test validating an expired API key."""
|
||||
# Setup - create an expired API key in the database
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
api_key_value = 'test-expired-key'
|
||||
|
||||
async with async_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key_value,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='Test Key',
|
||||
expires_at=datetime.now(UTC) - timedelta(days=1),
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.validate_api_key(api_key_value)
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_expired_timezone_naive(
|
||||
api_key_store, async_session_maker
|
||||
):
|
||||
"""Test validating an expired API key with timezone-naive datetime from database."""
|
||||
# Setup - create an expired API key with timezone-naive datetime
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
api_key_value = 'test-expired-naive-key'
|
||||
|
||||
async with async_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key_value,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='Test Key',
|
||||
# Timezone-naive datetime (database stores this)
|
||||
expires_at=datetime.now() - timedelta(days=1),
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.validate_api_key(api_key_value)
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_valid_timezone_naive(
|
||||
api_key_store, async_session_maker
|
||||
):
|
||||
"""Test validating a valid API key with timezone-naive datetime from database."""
|
||||
# Setup - create a valid API key with timezone-naive datetime (future date)
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
api_key_value = 'test-valid-naive-key'
|
||||
|
||||
async with async_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key_value,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='Test Key',
|
||||
# Timezone-naive datetime in the future
|
||||
expires_at=datetime.now() + timedelta(days=1),
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.validate_api_key(api_key_value)
|
||||
|
||||
# Verify
|
||||
assert result == user_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_not_found(api_key_store, async_session_maker):
|
||||
"""Test validating a non-existent API key."""
|
||||
# Execute
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.validate_api_key('non-existent-key')
|
||||
result = await api_key_store.create_api_key(user_id, name)
|
||||
|
||||
# Verify
|
||||
assert result == 'test-api-key'
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
api_key_store.generate_api_key.assert_called_once()
|
||||
|
||||
# Verify the ApiKey was created with the correct org_id
|
||||
added_api_key = mock_session.add.call_args[0][0]
|
||||
assert added_api_key.org_id == mock_user.current_org_id
|
||||
|
||||
|
||||
def test_validate_api_key_valid(api_key_store, mock_session):
|
||||
"""Test validating a valid API key."""
|
||||
# Setup
|
||||
api_key = 'test-api-key'
|
||||
user_id = 'test-user-123'
|
||||
mock_key_record = MagicMock()
|
||||
mock_key_record.user_id = user_id
|
||||
mock_key_record.expires_at = None
|
||||
mock_key_record.id = 1
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_key_record
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = api_key_store.validate_api_key(api_key)
|
||||
|
||||
# Verify
|
||||
assert result == user_id
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_validate_api_key_expired(api_key_store, mock_session):
|
||||
"""Test validating an expired API key."""
|
||||
# Setup
|
||||
api_key = 'test-api-key'
|
||||
mock_key_record = MagicMock()
|
||||
mock_key_record.expires_at = datetime.now(UTC) - timedelta(days=1)
|
||||
mock_key_record.id = 1
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_key_record
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = api_key_store.validate_api_key(api_key)
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
mock_session.execute.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key(api_key_store, async_session_maker):
|
||||
def test_validate_api_key_expired_timezone_naive(api_key_store, mock_session):
|
||||
"""Test validating an expired API key with timezone-naive datetime from database."""
|
||||
# Setup
|
||||
api_key = 'test-api-key'
|
||||
mock_key_record = MagicMock()
|
||||
# Simulate timezone-naive datetime as returned from database
|
||||
mock_key_record.expires_at = datetime.now() - timedelta(days=1) # No UTC timezone
|
||||
mock_key_record.id = 1
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_key_record
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = api_key_store.validate_api_key(api_key)
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
mock_session.execute.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_validate_api_key_valid_timezone_naive(api_key_store, mock_session):
|
||||
"""Test validating a valid API key with timezone-naive datetime from database."""
|
||||
# Setup
|
||||
api_key = 'test-api-key'
|
||||
user_id = 'test-user-123'
|
||||
mock_key_record = MagicMock()
|
||||
mock_key_record.user_id = user_id
|
||||
# Simulate timezone-naive datetime as returned from database (future date)
|
||||
mock_key_record.expires_at = datetime.now() + timedelta(days=1) # No UTC timezone
|
||||
mock_key_record.id = 1
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_key_record
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = api_key_store.validate_api_key(api_key)
|
||||
|
||||
# Verify
|
||||
assert result == user_id
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_validate_api_key_not_found(api_key_store, mock_session):
|
||||
"""Test validating a non-existent API key."""
|
||||
# Setup
|
||||
api_key = 'test-api-key'
|
||||
query_result = mock_session.query.return_value.filter.return_value
|
||||
query_result.first.return_value = None
|
||||
|
||||
# Execute
|
||||
result = api_key_store.validate_api_key(api_key)
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
mock_session.execute.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_delete_api_key(api_key_store, mock_session):
|
||||
"""Test deleting an API key."""
|
||||
# Setup - create an API key in the database
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
api_key_value = 'test-delete-key'
|
||||
# Setup
|
||||
api_key = 'test-api-key'
|
||||
mock_key_record = MagicMock()
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_key_record
|
||||
)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key_value,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='Test Key',
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.delete_api_key(api_key_value)
|
||||
# Execute
|
||||
result = api_key_store.delete_api_key(api_key)
|
||||
|
||||
# Verify
|
||||
assert result is True
|
||||
|
||||
# Verify it was deleted from the database
|
||||
async with async_session_maker() as session:
|
||||
result_db = await session.execute(
|
||||
select(ApiKey).filter(ApiKey.key == api_key_value)
|
||||
)
|
||||
api_key = result_db.scalars().first()
|
||||
assert api_key is None
|
||||
mock_session.delete.assert_called_once_with(mock_key_record)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key_not_found(api_key_store, async_session_maker):
|
||||
def test_delete_api_key_not_found(api_key_store, mock_session):
|
||||
"""Test deleting a non-existent API key."""
|
||||
# Setup
|
||||
api_key = 'test-api-key'
|
||||
query_result = mock_session.query.return_value.filter.return_value
|
||||
query_result.first.return_value = None
|
||||
|
||||
# Execute
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.delete_api_key('non-existent-key')
|
||||
result = api_key_store.delete_api_key(api_key)
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
mock_session.delete.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key_by_id(api_key_store, async_session_maker):
|
||||
def test_delete_api_key_by_id(api_key_store, mock_session):
|
||||
"""Test deleting an API key by ID."""
|
||||
# Setup - create an API key in the database
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
# Setup
|
||||
key_id = 123
|
||||
mock_key_record = MagicMock()
|
||||
mock_session.query.return_value.filter.return_value.first.return_value = (
|
||||
mock_key_record
|
||||
)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key='test-delete-by-id-key',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='Test Key',
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
key_id = key_record.id
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.delete_api_key_by_id(key_id)
|
||||
# Execute
|
||||
result = api_key_store.delete_api_key_by_id(key_id)
|
||||
|
||||
# Verify
|
||||
assert result is True
|
||||
|
||||
# Verify it was deleted from the database
|
||||
async with async_session_maker() as session:
|
||||
result_db = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
api_key = result_db.scalars().first()
|
||||
assert api_key is None
|
||||
mock_session.delete.assert_called_once_with(mock_key_record)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
async def test_list_api_keys(
|
||||
mock_get_user, api_key_store, async_session_maker, mock_user
|
||||
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
|
||||
):
|
||||
"""Test listing API keys for a user."""
|
||||
# Setup
|
||||
user_id = str(uuid.uuid4())
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
now = datetime.now(UTC)
|
||||
mock_key1 = MagicMock()
|
||||
mock_key1.id = 1
|
||||
mock_key1.name = 'Key 1'
|
||||
mock_key1.created_at = now
|
||||
mock_key1.last_used_at = now
|
||||
mock_key1.expires_at = now + timedelta(days=30)
|
||||
|
||||
# Create API keys in the database
|
||||
async with async_session_maker() as session:
|
||||
key1 = ApiKey(
|
||||
key='test-key-1',
|
||||
user_id=user_id,
|
||||
org_id=mock_user.current_org_id,
|
||||
name='Key 1',
|
||||
created_at=now,
|
||||
last_used_at=now,
|
||||
expires_at=now + timedelta(days=30),
|
||||
)
|
||||
key2 = ApiKey(
|
||||
key='test-key-2',
|
||||
user_id=user_id,
|
||||
org_id=mock_user.current_org_id,
|
||||
name='Key 2',
|
||||
created_at=now,
|
||||
last_used_at=None,
|
||||
expires_at=None,
|
||||
)
|
||||
# Add an MCP key that should be filtered out
|
||||
mcp_key = ApiKey(
|
||||
key='test-mcp-key',
|
||||
user_id=user_id,
|
||||
org_id=mock_user.current_org_id,
|
||||
name='MCP_API_KEY',
|
||||
created_at=now,
|
||||
)
|
||||
session.add_all([key1, key2, mcp_key])
|
||||
await session.commit()
|
||||
mock_key2 = MagicMock()
|
||||
mock_key2.id = 2
|
||||
mock_key2.name = 'Key 2'
|
||||
mock_key2.created_at = now
|
||||
mock_key2.last_used_at = None
|
||||
mock_key2.expires_at = None
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.list_api_keys(user_id)
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_key1, mock_key2]
|
||||
|
||||
# Execute
|
||||
result = await api_key_store.list_api_keys(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert len(result) == 2
|
||||
assert result[0].id == 1
|
||||
assert result[0].name == 'Key 1'
|
||||
assert result[0].created_at == now
|
||||
assert result[0].last_used_at == now
|
||||
assert result[0].expires_at == now + timedelta(days=30)
|
||||
|
||||
assert result[1].id == 2
|
||||
assert result[1].name == 'Key 2'
|
||||
assert result[1].created_at == now
|
||||
assert result[1].last_used_at is None
|
||||
assert result[1].expires_at is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
async def test_retrieve_mcp_api_key(
|
||||
mock_get_user, api_key_store, async_session_maker, mock_user
|
||||
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
|
||||
):
|
||||
"""Test retrieving MCP API key for a user."""
|
||||
# Setup
|
||||
user_id = str(uuid.uuid4())
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Create API keys in the database
|
||||
async with async_session_maker() as session:
|
||||
other_key = ApiKey(
|
||||
key='test-other-key',
|
||||
user_id=user_id,
|
||||
org_id=mock_user.current_org_id,
|
||||
name='Other Key',
|
||||
created_at=now,
|
||||
)
|
||||
mcp_key = ApiKey(
|
||||
key='test-mcp-key',
|
||||
user_id=user_id,
|
||||
org_id=mock_user.current_org_id,
|
||||
name='MCP_API_KEY',
|
||||
created_at=now,
|
||||
)
|
||||
session.add_all([other_key, mcp_key])
|
||||
await session.commit()
|
||||
mock_mcp_key = MagicMock()
|
||||
mock_mcp_key.name = 'MCP_API_KEY'
|
||||
mock_mcp_key.key = 'mcp-test-key'
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.retrieve_mcp_api_key(user_id)
|
||||
mock_other_key = MagicMock()
|
||||
mock_other_key.name = 'Other Key'
|
||||
mock_other_key.key = 'other-test-key'
|
||||
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_other_key, mock_mcp_key]
|
||||
|
||||
# Execute
|
||||
result = await api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert result == 'test-mcp-key'
|
||||
assert result == 'mcp-test-key'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
async def test_retrieve_mcp_api_key_not_found(
|
||||
mock_get_user, api_key_store, async_session_maker, mock_user
|
||||
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
|
||||
):
|
||||
"""Test retrieving MCP API key when none exists."""
|
||||
# Setup
|
||||
user_id = str(uuid.uuid4())
|
||||
user_id = 'test-user-123'
|
||||
mock_get_user.return_value = mock_user
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Create only non-MCP keys in the database
|
||||
async with async_session_maker() as session:
|
||||
other_key = ApiKey(
|
||||
key='test-other-key',
|
||||
user_id=user_id,
|
||||
org_id=mock_user.current_org_id,
|
||||
name='Other Key',
|
||||
created_at=now,
|
||||
)
|
||||
session.add(other_key)
|
||||
await session.commit()
|
||||
mock_other_key = MagicMock()
|
||||
mock_other_key.name = 'Other Key'
|
||||
mock_other_key.key = 'other-test-key'
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.retrieve_mcp_api_key(user_id)
|
||||
# Mock the chained query calls for filtering by user_id and org_id
|
||||
mock_query = mock_session.query.return_value
|
||||
mock_filter_user = mock_query.filter.return_value
|
||||
mock_filter_org = mock_filter_user.filter.return_value
|
||||
mock_filter_org.all.return_value = [mock_other_key]
|
||||
|
||||
# Execute
|
||||
result = await api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_api_key_by_name(api_key_store, async_session_maker):
|
||||
"""Test retrieving an API key by name."""
|
||||
# Setup
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
key_name = 'Test Key'
|
||||
key_value = 'test-key-by-name'
|
||||
|
||||
async with async_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=key_value,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=key_name,
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.retrieve_api_key_by_name(user_id, key_name)
|
||||
|
||||
# Verify
|
||||
assert result == key_value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_api_key_by_name_not_found(api_key_store, async_session_maker):
|
||||
"""Test retrieving an API key by name that doesn't exist."""
|
||||
# Execute
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.retrieve_api_key_by_name(
|
||||
'non-existent-user', 'Non Existent Key'
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key_by_name(api_key_store, async_session_maker):
|
||||
"""Test deleting an API key by name."""
|
||||
# Setup
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
key_name = 'Test Key to Delete'
|
||||
key_value = 'test-delete-by-name'
|
||||
|
||||
async with async_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=key_value,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=key_name,
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.delete_api_key_by_name(user_id, key_name)
|
||||
|
||||
# Verify
|
||||
assert result is True
|
||||
|
||||
# Verify it was deleted from the database
|
||||
async with async_session_maker() as session:
|
||||
result_db = await session.execute(
|
||||
select(ApiKey).filter(ApiKey.key == key_value)
|
||||
)
|
||||
api_key = result_db.scalars().first()
|
||||
assert api_key is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key_by_name_not_found(api_key_store, async_session_maker):
|
||||
"""Test deleting an API key by name that doesn't exist."""
|
||||
# Execute
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.delete_api_key_by_name(
|
||||
'non-existent-user', 'Non Existent Key'
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
@@ -595,7 +595,7 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=True)
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
@@ -621,7 +621,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
@@ -660,7 +660,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -686,7 +686,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
@@ -725,7 +725,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -749,7 +749,7 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
@@ -898,7 +898,7 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
@@ -959,7 +959,7 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
@@ -1022,7 +1022,7 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
):
|
||||
# Arrange
|
||||
@@ -1174,7 +1174,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
@@ -1221,7 +1221,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Patch the module-level recaptcha_service instance
|
||||
mock_recaptcha_service.create_assessment.return_value = (
|
||||
@@ -1284,7 +1284,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Patch the module-level recaptcha_service instance
|
||||
mock_recaptcha_service.create_assessment.return_value = (
|
||||
@@ -1325,7 +1325,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
@@ -1371,7 +1371,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Patch the module-level recaptcha_service instance
|
||||
mock_recaptcha_service.create_assessment.return_value = (
|
||||
@@ -1414,7 +1414,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
@@ -1460,7 +1460,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Patch the module-level recaptcha_service instance
|
||||
mock_recaptcha_service.create_assessment.return_value = (
|
||||
@@ -1500,7 +1500,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
@@ -1546,7 +1546,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Patch the module-level recaptcha_service instance
|
||||
mock_recaptcha_service.create_assessment.return_value = (
|
||||
@@ -1585,7 +1585,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
@@ -1631,7 +1631,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Patch the module-level recaptcha_service instance
|
||||
mock_recaptcha_service.create_assessment.return_value = (
|
||||
@@ -1666,7 +1666,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', ''),
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
@@ -1713,7 +1713,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
await keycloak_callback(
|
||||
@@ -1734,7 +1734,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
@@ -1781,7 +1781,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
await keycloak_callback(code='test_code', state=state, request=mock_request)
|
||||
@@ -1808,7 +1808,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
patch('server.routes.auth.recaptcha_service') as mock_recaptcha_service,
|
||||
patch('server.routes.auth.RECAPTCHA_SITE_KEY', 'test-site-key'),
|
||||
patch('server.routes.auth.user_verifier') as mock_verifier,
|
||||
patch('server.routes.auth.a_session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.session_maker') as mock_session_maker,
|
||||
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch('server.routes.auth.posthog'),
|
||||
@@ -1855,7 +1855,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
mock_recaptcha_service.create_assessment.side_effect = Exception(
|
||||
'Service error'
|
||||
@@ -1924,7 +1924,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_domain_blocked = AsyncMock(return_value=False)
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
# Patch the module-level recaptcha_service instance
|
||||
mock_recaptcha_service.create_assessment.return_value = (
|
||||
|
||||
@@ -6,7 +6,6 @@ import pytest
|
||||
import stripe
|
||||
from fastapi import HTTPException, Request, status
|
||||
from httpx import Response
|
||||
from server.constants import ORG_SETTINGS_VERSION
|
||||
from server.routes import billing
|
||||
from server.routes.billing import (
|
||||
CreateBillingSessionResponse,
|
||||
@@ -19,11 +18,22 @@ from server.routes.billing import (
|
||||
has_payment_method,
|
||||
success_callback,
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
from storage.stripe_customer import Base as StripeCustomerBase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
StripeCustomerBase.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -66,38 +76,6 @@ def mock_subscription_request():
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_org(async_session_maker):
|
||||
"""Create a test org in the database."""
|
||||
org_id = uuid.uuid4()
|
||||
async with async_session_maker() as session:
|
||||
org = Org(
|
||||
id=org_id,
|
||||
name=f'test-org-{org_id}',
|
||||
org_version=ORG_SETTINGS_VERSION,
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
return org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_user(async_session_maker, test_org):
|
||||
"""Create a test user in the database linked to test_org."""
|
||||
user_id = uuid.uuid4()
|
||||
async with async_session_maker() as session:
|
||||
user = User(
|
||||
id=user_id,
|
||||
current_org_id=test_org.id,
|
||||
user_consents_to_analytics=True,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_lite_llm_error():
|
||||
with (
|
||||
@@ -155,14 +133,17 @@ async def test_get_credits_success():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_checkout_session_stripe_error(
|
||||
async_session_maker, mock_checkout_request, test_org
|
||||
session_maker, mock_checkout_request
|
||||
):
|
||||
"""Test handling of Stripe API errors."""
|
||||
|
||||
mock_customer = stripe.Customer(
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
)
|
||||
mock_customer_create = AsyncMock(return_value=mock_customer)
|
||||
|
||||
mock_org = MagicMock()
|
||||
mock_org.id = uuid.uuid4()
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
with (
|
||||
pytest.raises(Exception, match='Stripe API Error'),
|
||||
patch('stripe.Customer.create_async', mock_customer_create),
|
||||
@@ -173,13 +154,10 @@ async def test_create_checkout_session_stripe_error(
|
||||
'stripe.checkout.Session.create_async',
|
||||
AsyncMock(side_effect=Exception('Stripe API Error')),
|
||||
),
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('integrations.stripe_service.a_session_maker', async_session_maker),
|
||||
patch('storage.database.a_session_maker', async_session_maker),
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=test_org,
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
@@ -193,27 +171,44 @@ async def test_create_checkout_session_stripe_error(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_checkout_session_success(
|
||||
async_session_maker, mock_checkout_request, test_org
|
||||
):
|
||||
async def test_create_checkout_session_success(session_maker, mock_checkout_request):
|
||||
"""Test successful creation of checkout session."""
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.url = 'https://checkout.stripe.com/test-session'
|
||||
mock_session.id = 'test_session_id_checkout'
|
||||
mock_session.id = 'test_session_id'
|
||||
mock_create = AsyncMock(return_value=mock_session)
|
||||
mock_create.return_value = mock_session
|
||||
|
||||
mock_customer_info = {'customer_id': 'mock-customer', 'org_id': test_org.id}
|
||||
|
||||
mock_customer = stripe.Customer(
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
)
|
||||
mock_customer_create = AsyncMock(return_value=mock_customer)
|
||||
mock_org = MagicMock()
|
||||
mock_org_id = uuid.uuid4()
|
||||
mock_org.id = mock_org_id
|
||||
mock_org.contact_email = 'testy@tester.com'
|
||||
with (
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('integrations.stripe_service.a_session_maker', async_session_maker),
|
||||
patch('stripe.Customer.create_async', mock_customer_create),
|
||||
patch(
|
||||
'integrations.stripe_service.find_or_create_customer_by_user_id',
|
||||
AsyncMock(return_value=mock_customer_info),
|
||||
'stripe.Customer.search_async', AsyncMock(return_value=MagicMock(data=[]))
|
||||
),
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('integrations.stripe_service.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('server.routes.billing.validate_billing_enabled'),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
result = await create_checkout_session(
|
||||
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
|
||||
)
|
||||
@@ -245,102 +240,74 @@ async def test_create_checkout_session_success(
|
||||
cancel_url='https://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
|
||||
)
|
||||
|
||||
# Verify database record was created
|
||||
async with async_session_maker() as session:
|
||||
result_db = await session.execute(
|
||||
select(BillingSession).where(
|
||||
BillingSession.id == 'test_session_id_checkout'
|
||||
)
|
||||
)
|
||||
billing_session = result_db.scalar_one_or_none()
|
||||
assert billing_session is not None
|
||||
assert billing_session.user_id == 'mock_user'
|
||||
assert billing_session.org_id == test_org.id
|
||||
assert billing_session.status == 'in_progress'
|
||||
assert float(billing_session.price) == 25.0
|
||||
# Verify database session creation
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback_session_not_found(async_session_maker):
|
||||
async def test_success_callback_session_not_found():
|
||||
"""Test success callback when billing session is not found."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('stripe.checkout.Session.retrieve'),
|
||||
):
|
||||
with patch('server.routes.billing.session_maker') as mock_session_maker:
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await success_callback('nonexistent_session_id', mock_request)
|
||||
await success_callback('test_session_id', mock_request)
|
||||
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||
mock_db_session.merge.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback_stripe_incomplete(
|
||||
async_session_maker, test_org, test_user
|
||||
):
|
||||
async def test_success_callback_stripe_incomplete():
|
||||
"""Test success callback when Stripe session is not complete."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
session_id = 'test_incomplete_session'
|
||||
async with async_session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=session_id,
|
||||
user_id=str(test_user.id),
|
||||
org_id=test_org.id,
|
||||
status='in_progress',
|
||||
price=25,
|
||||
price_code='NA',
|
||||
)
|
||||
session.add(billing_session)
|
||||
await session.commit()
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(status='pending')
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await success_callback(session_id, mock_request)
|
||||
await success_callback('test_session_id', mock_request)
|
||||
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
# Verify no database update occurred
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(BillingSession.id == session_id)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
assert billing_session.status == 'in_progress'
|
||||
mock_db_session.merge.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback_success(async_session_maker, test_org, test_user):
|
||||
async def test_success_callback_success():
|
||||
"""Test successful payment completion and credit update."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
session_id = 'test_success_session'
|
||||
async with async_session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=session_id,
|
||||
user_id=str(test_user.id),
|
||||
org_id=test_org.id,
|
||||
status='in_progress',
|
||||
price=25,
|
||||
price_code='NA',
|
||||
)
|
||||
session.add(billing_session)
|
||||
await session.commit()
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
|
||||
mock_org = MagicMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id=test_org.id),
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
@@ -353,11 +320,25 @@ async def test_success_callback_success(async_session_maker, test_org, test_user
|
||||
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
|
||||
) as mock_update_budget,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
# First query: BillingSession (query().filter().filter().first())
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
# Second query: Org (query().filter().first()) - use side_effect for different return chains
|
||||
mock_query_chain_billing = MagicMock()
|
||||
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_query_chain_org = MagicMock()
|
||||
mock_query_chain_org.filter.return_value.first.return_value = mock_org
|
||||
mock_db_session.query.side_effect = [
|
||||
mock_query_chain_billing,
|
||||
mock_query_chain_org,
|
||||
]
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete', amount_subtotal=2500, customer='mock_customer_id'
|
||||
)
|
||||
) # $25.00 in cents
|
||||
|
||||
response = await success_callback(session_id, mock_request)
|
||||
response = await success_callback('test_session_id', mock_request)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
@@ -365,80 +346,64 @@ async def test_success_callback_success(async_session_maker, test_org, test_user
|
||||
== 'https://test.com/settings/billing?checkout=success'
|
||||
)
|
||||
|
||||
# Verify LiteLLM API calls
|
||||
mock_update_budget.assert_called_once_with(
|
||||
str(test_org.id),
|
||||
'mock_org_id',
|
||||
125.0, # 100 + 25.00
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(BillingSession.id == session_id)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
assert billing_session.status == 'completed'
|
||||
assert float(billing_session.price) == 25.0
|
||||
# Verify BYOR export is enabled for the org (updated in same session)
|
||||
assert mock_org.byor_export_enabled is True
|
||||
|
||||
# Verify org byor_export_enabled was set
|
||||
org_result = await session.execute(select(Org).where(Org.id == test_org.id))
|
||||
org = org_result.scalar_one_or_none()
|
||||
assert org.byor_export_enabled is True
|
||||
# Verify database updates
|
||||
assert mock_billing_session.status == 'completed'
|
||||
assert mock_billing_session.price == 25.0
|
||||
mock_db_session.merge.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback_lite_llm_error(
|
||||
async_session_maker, test_org, test_user
|
||||
):
|
||||
async def test_success_callback_lite_llm_error():
|
||||
"""Test handling of LiteLLM API errors during success callback."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
session_id = 'test_litellm_error_session'
|
||||
async with async_session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=session_id,
|
||||
user_id=str(test_user.id),
|
||||
org_id=test_org.id,
|
||||
status='in_progress',
|
||||
price=25,
|
||||
price_code='NA',
|
||||
)
|
||||
session.add(billing_session)
|
||||
await session.commit()
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id=test_org.id),
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete', amount_subtotal=2500
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await success_callback(session_id, mock_request)
|
||||
await success_callback('test_session_id', mock_request)
|
||||
|
||||
# Verify no database updates occurred (transaction rolled back)
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(BillingSession.id == session_id)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
assert billing_session.status == 'in_progress'
|
||||
# Verify no database updates occurred
|
||||
assert mock_billing_session.status == 'in_progress'
|
||||
mock_db_session.merge.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback_lite_llm_update_budget_error_rollback(
|
||||
async_session_maker, test_org, test_user
|
||||
):
|
||||
async def test_success_callback_lite_llm_update_budget_error_rollback():
|
||||
"""Test that database changes are not committed when update_team_and_users_budget fails.
|
||||
|
||||
This test verifies that if LiteLlmManager.update_team_and_users_budget raises an exception,
|
||||
@@ -447,26 +412,19 @@ async def test_success_callback_lite_llm_update_budget_error_rollback(
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
session_id = 'test_budget_rollback_session'
|
||||
async with async_session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=session_id,
|
||||
user_id=str(test_user.id),
|
||||
org_id=test_org.id,
|
||||
status='in_progress',
|
||||
price=10,
|
||||
price_code='NA',
|
||||
)
|
||||
session.add(billing_session)
|
||||
await session.commit()
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
|
||||
mock_org = MagicMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id=test_org.id),
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
@@ -480,60 +438,70 @@ async def test_success_callback_lite_llm_update_budget_error_rollback(
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_query_chain_billing = MagicMock()
|
||||
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_query_chain_org = MagicMock()
|
||||
mock_query_chain_org.filter.return_value.first.return_value = mock_org
|
||||
mock_db_session.query.side_effect = [
|
||||
mock_query_chain_billing,
|
||||
mock_query_chain_org,
|
||||
]
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete',
|
||||
amount_subtotal=1000,
|
||||
amount_subtotal=1000, # $10
|
||||
customer='mock_customer_id',
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await success_callback(session_id, mock_request)
|
||||
await success_callback('test_session_id', mock_request)
|
||||
|
||||
# Verify no database commit occurred - the transaction should roll back
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(BillingSession.id == session_id)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
assert billing_session.status == 'in_progress'
|
||||
# Verify no database commit occurred - the transaction should roll back
|
||||
assert mock_billing_session.status == 'in_progress'
|
||||
mock_db_session.merge.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_callback_session_not_found(async_session_maker):
|
||||
async def test_cancel_callback_session_not_found():
|
||||
"""Test cancel callback when billing session is not found."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
with patch('server.routes.billing.a_session_maker', async_session_maker):
|
||||
response = await cancel_callback('nonexistent_session_id', mock_request)
|
||||
with patch('server.routes.billing.session_maker') as mock_session_maker:
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = None
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
response = await cancel_callback('test_session_id', mock_request)
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location']
|
||||
== 'https://test.com/settings/billing?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify no database updates occurred
|
||||
mock_db_session.merge.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_callback_success(async_session_maker, test_org, test_user):
|
||||
async def test_cancel_callback_success():
|
||||
"""Test successful cancellation of billing session."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
session_id = 'test_cancel_session'
|
||||
async with async_session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=session_id,
|
||||
user_id=str(test_user.id),
|
||||
org_id=test_org.id,
|
||||
status='in_progress',
|
||||
price=25,
|
||||
price_code='NA',
|
||||
)
|
||||
session.add(billing_session)
|
||||
await session.commit()
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
|
||||
with patch('server.routes.billing.a_session_maker', async_session_maker):
|
||||
response = await cancel_callback(session_id, mock_request)
|
||||
with patch('server.routes.billing.session_maker') as mock_session_maker:
|
||||
mock_db_session = MagicMock()
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
response = await cancel_callback('test_session_id', mock_request)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
@@ -541,18 +509,16 @@ async def test_cancel_callback_success(async_session_maker, test_org, test_user)
|
||||
== 'https://test.com/settings/billing?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify database update
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(BillingSession.id == session_id)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
assert billing_session.status == 'cancelled'
|
||||
# Verify database updates
|
||||
assert mock_billing_session.status == 'cancelled'
|
||||
mock_db_session.merge.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_payment_method_with_payment_method():
|
||||
"""Test has_payment_method returns True when user has a payment method."""
|
||||
|
||||
mock_has_payment_method = AsyncMock(return_value=True)
|
||||
with patch(
|
||||
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Unit tests for DomainBlocker class."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from server.auth.domain_blocker import DomainBlocker
|
||||
@@ -9,9 +9,7 @@ from server.auth.domain_blocker import DomainBlocker
|
||||
@pytest.fixture
|
||||
def mock_store():
|
||||
"""Create a mock BlockedEmailDomainStore for testing."""
|
||||
store = MagicMock()
|
||||
store.is_domain_blocked = AsyncMock()
|
||||
return store
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -59,120 +57,109 @@ def test_extract_domain_invalid_emails(domain_blocker, email, expected):
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_with_none_email(domain_blocker, mock_store):
|
||||
def test_is_domain_blocked_with_none_email(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when email is None."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked(None)
|
||||
result = domain_blocker.is_domain_blocked(None)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_with_empty_email(domain_blocker, mock_store):
|
||||
def test_is_domain_blocked_with_empty_email(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when email is empty."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('')
|
||||
result = domain_blocker.is_domain_blocked('')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_with_invalid_email(domain_blocker, mock_store):
|
||||
def test_is_domain_blocked_with_invalid_email(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when email format is invalid."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('invalid-email')
|
||||
result = domain_blocker.is_domain_blocked('invalid-email')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_not_blocked(domain_blocker, mock_store):
|
||||
def test_is_domain_blocked_domain_not_blocked(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when domain is not blocked."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@example.com')
|
||||
result = domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_called_once_with('example.com')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_blocked(domain_blocker, mock_store):
|
||||
def test_is_domain_blocked_domain_blocked(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns True when domain is blocked."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@colsch.us')
|
||||
result = domain_blocker.is_domain_blocked('user@colsch.us')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_case_insensitive(domain_blocker, mock_store):
|
||||
def test_is_domain_blocked_case_insensitive(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked performs case-insensitive domain extraction."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@COLSCH.US')
|
||||
result = domain_blocker.is_domain_blocked('user@COLSCH.US')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_with_whitespace(domain_blocker, mock_store):
|
||||
def test_is_domain_blocked_with_whitespace(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked handles emails with whitespace correctly."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked(' user@colsch.us ')
|
||||
result = domain_blocker.is_domain_blocked(' user@colsch.us ')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('colsch.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_multiple_blocked_domains(domain_blocker, mock_store):
|
||||
def test_is_domain_blocked_multiple_blocked_domains(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked correctly checks multiple domains."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked = AsyncMock(
|
||||
side_effect=lambda domain: domain
|
||||
in [
|
||||
'other-domain.com',
|
||||
'blocked.org',
|
||||
]
|
||||
)
|
||||
mock_store.is_domain_blocked.side_effect = lambda domain: domain in [
|
||||
'other-domain.com',
|
||||
'blocked.org',
|
||||
]
|
||||
|
||||
# Act
|
||||
result1 = await domain_blocker.is_domain_blocked('user@other-domain.com')
|
||||
result2 = await domain_blocker.is_domain_blocked('user@blocked.org')
|
||||
result3 = await domain_blocker.is_domain_blocked('user@allowed.com')
|
||||
result1 = domain_blocker.is_domain_blocked('user@other-domain.com')
|
||||
result2 = domain_blocker.is_domain_blocked('user@blocked.org')
|
||||
result3 = domain_blocker.is_domain_blocked('user@allowed.com')
|
||||
|
||||
# Assert
|
||||
assert result1 is True
|
||||
@@ -181,8 +168,7 @@ async def test_is_domain_blocked_multiple_blocked_domains(domain_blocker, mock_s
|
||||
assert mock_store.is_domain_blocked.call_count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_tld_pattern_blocks_matching_domain(
|
||||
def test_is_domain_blocked_tld_pattern_blocks_matching_domain(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that TLD pattern blocks domains ending with that TLD."""
|
||||
@@ -190,15 +176,14 @@ async def test_is_domain_blocked_tld_pattern_blocks_matching_domain(
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@company.us')
|
||||
result = domain_blocker.is_domain_blocked('user@company.us')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('company.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld(
|
||||
def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that TLD pattern blocks subdomains with that TLD."""
|
||||
@@ -206,15 +191,14 @@ async def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld(
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@subdomain.company.us')
|
||||
result = domain_blocker.is_domain_blocked('user@subdomain.company.us')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('subdomain.company.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_tld_pattern_does_not_block_different_tld(
|
||||
def test_is_domain_blocked_tld_pattern_does_not_block_different_tld(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that TLD pattern does not block domains with different TLD."""
|
||||
@@ -222,41 +206,35 @@ async def test_is_domain_blocked_tld_pattern_does_not_block_different_tld(
|
||||
mock_store.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@company.com')
|
||||
result = domain_blocker.is_domain_blocked('user@company.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_called_once_with('company.com')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_tld_pattern_case_insensitive(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
def test_is_domain_blocked_tld_pattern_case_insensitive(domain_blocker, mock_store):
|
||||
"""Test that TLD pattern matching is case-insensitive."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@COMPANY.US')
|
||||
result = domain_blocker.is_domain_blocked('user@COMPANY.US')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('company.us')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_tld_pattern_with_multi_level_tld(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
def test_is_domain_blocked_tld_pattern_with_multi_level_tld(domain_blocker, mock_store):
|
||||
"""Test that TLD pattern works with multi-level TLDs like .co.uk."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.side_effect = lambda domain: domain.endswith('.co.uk')
|
||||
|
||||
# Act
|
||||
result_match = await domain_blocker.is_domain_blocked('user@example.co.uk')
|
||||
result_subdomain = await domain_blocker.is_domain_blocked('user@api.example.co.uk')
|
||||
result_no_match = await domain_blocker.is_domain_blocked('user@example.uk')
|
||||
result_match = domain_blocker.is_domain_blocked('user@example.co.uk')
|
||||
result_subdomain = domain_blocker.is_domain_blocked('user@api.example.co.uk')
|
||||
result_no_match = domain_blocker.is_domain_blocked('user@example.uk')
|
||||
|
||||
# Assert
|
||||
assert result_match is True
|
||||
@@ -264,8 +242,7 @@ async def test_is_domain_blocked_tld_pattern_with_multi_level_tld(
|
||||
assert result_no_match is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_pattern_blocks_exact_match(
|
||||
def test_is_domain_blocked_domain_pattern_blocks_exact_match(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that domain pattern blocks exact domain match."""
|
||||
@@ -273,31 +250,27 @@ async def test_is_domain_blocked_domain_pattern_blocks_exact_match(
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@example.com')
|
||||
result = domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('example.com')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_pattern_blocks_subdomain(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
def test_is_domain_blocked_domain_pattern_blocks_subdomain(domain_blocker, mock_store):
|
||||
"""Test that domain pattern blocks subdomains of that domain."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@subdomain.example.com')
|
||||
result = domain_blocker.is_domain_blocked('user@subdomain.example.com')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('subdomain.example.com')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_pattern_blocks_multi_level_subdomain(
|
||||
def test_is_domain_blocked_domain_pattern_blocks_multi_level_subdomain(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that domain pattern blocks multi-level subdomains."""
|
||||
@@ -305,15 +278,14 @@ async def test_is_domain_blocked_domain_pattern_blocks_multi_level_subdomain(
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@api.v2.example.com')
|
||||
result = domain_blocker.is_domain_blocked('user@api.v2.example.com')
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_store.is_domain_blocked.assert_called_once_with('api.v2.example.com')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_pattern_does_not_block_similar_domain(
|
||||
def test_is_domain_blocked_domain_pattern_does_not_block_similar_domain(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that domain pattern does not block domains that contain but don't match the pattern."""
|
||||
@@ -321,15 +293,14 @@ async def test_is_domain_blocked_domain_pattern_does_not_block_similar_domain(
|
||||
mock_store.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@notexample.com')
|
||||
result = domain_blocker.is_domain_blocked('user@notexample.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_called_once_with('notexample.com')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_pattern_does_not_block_different_tld(
|
||||
def test_is_domain_blocked_domain_pattern_does_not_block_different_tld(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that domain pattern does not block same domain with different TLD."""
|
||||
@@ -337,15 +308,14 @@ async def test_is_domain_blocked_domain_pattern_does_not_block_different_tld(
|
||||
mock_store.is_domain_blocked.return_value = False
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@example.org')
|
||||
result = domain_blocker.is_domain_blocked('user@example.org')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
mock_store.is_domain_blocked.assert_called_once_with('example.org')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(
|
||||
def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(
|
||||
domain_blocker, mock_store
|
||||
):
|
||||
"""Test that blocking a subdomain also blocks its nested subdomains."""
|
||||
@@ -355,9 +325,9 @@ async def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(
|
||||
)
|
||||
|
||||
# Act
|
||||
result_exact = await domain_blocker.is_domain_blocked('user@api.example.com')
|
||||
result_nested = await domain_blocker.is_domain_blocked('user@v1.api.example.com')
|
||||
result_parent = await domain_blocker.is_domain_blocked('user@example.com')
|
||||
result_exact = domain_blocker.is_domain_blocked('user@api.example.com')
|
||||
result_nested = domain_blocker.is_domain_blocked('user@v1.api.example.com')
|
||||
result_parent = domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result_exact is True
|
||||
@@ -365,15 +335,14 @@ async def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(
|
||||
assert result_parent is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_with_hyphens(domain_blocker, mock_store):
|
||||
def test_is_domain_blocked_domain_with_hyphens(domain_blocker, mock_store):
|
||||
"""Test that domain patterns work with hyphenated domains."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result_exact = await domain_blocker.is_domain_blocked('user@my-company.com')
|
||||
result_subdomain = await domain_blocker.is_domain_blocked('user@api.my-company.com')
|
||||
result_exact = domain_blocker.is_domain_blocked('user@my-company.com')
|
||||
result_subdomain = domain_blocker.is_domain_blocked('user@api.my-company.com')
|
||||
|
||||
# Assert
|
||||
assert result_exact is True
|
||||
@@ -381,15 +350,14 @@ async def test_is_domain_blocked_domain_with_hyphens(domain_blocker, mock_store)
|
||||
assert mock_store.is_domain_blocked.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_domain_with_numbers(domain_blocker, mock_store):
|
||||
def test_is_domain_blocked_domain_with_numbers(domain_blocker, mock_store):
|
||||
"""Test that domain patterns work with numeric domains."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result_exact = await domain_blocker.is_domain_blocked('user@test123.com')
|
||||
result_subdomain = await domain_blocker.is_domain_blocked('user@api.test123.com')
|
||||
result_exact = domain_blocker.is_domain_blocked('user@test123.com')
|
||||
result_subdomain = domain_blocker.is_domain_blocked('user@api.test123.com')
|
||||
|
||||
# Assert
|
||||
assert result_exact is True
|
||||
@@ -397,14 +365,13 @@ async def test_is_domain_blocked_domain_with_numbers(domain_blocker, mock_store)
|
||||
assert mock_store.is_domain_blocked.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker, mock_store):
|
||||
def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker, mock_store):
|
||||
"""Test that blocking works with very long subdomain chains."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.return_value = True
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked(
|
||||
result = domain_blocker.is_domain_blocked(
|
||||
'user@level4.level3.level2.level1.example.com'
|
||||
)
|
||||
|
||||
@@ -415,14 +382,13 @@ async def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker, mock_
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_domain_blocked_handles_store_exception(domain_blocker, mock_store):
|
||||
def test_is_domain_blocked_handles_store_exception(domain_blocker, mock_store):
|
||||
"""Test that is_domain_blocked returns False when store raises an exception."""
|
||||
# Arrange
|
||||
mock_store.is_domain_blocked.side_effect = Exception('Database connection error')
|
||||
|
||||
# Act
|
||||
result = await domain_blocker.is_domain_blocked('user@example.com')
|
||||
result = domain_blocker.is_domain_blocked('user@example.com')
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
@@ -28,7 +27,6 @@ async def test_submit_feedback():
|
||||
"""Test submitting feedback for a conversation."""
|
||||
# Create a mock database session
|
||||
mock_session = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
# Test data
|
||||
feedback_data = FeedbackRequest(
|
||||
@@ -39,13 +37,19 @@ async def test_submit_feedback():
|
||||
metadata={'browser': 'Chrome', 'os': 'Windows'},
|
||||
)
|
||||
|
||||
# Create async context manager for a_session_maker
|
||||
@asynccontextmanager
|
||||
async def mock_a_session_maker():
|
||||
yield mock_session
|
||||
# Mock session_maker and call_sync_from_async
|
||||
with patch('server.routes.feedback.session_maker') as mock_session_maker, patch(
|
||||
'server.routes.feedback.call_sync_from_async'
|
||||
) as mock_call_sync:
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
|
||||
# Mock call_sync_from_async to execute the function
|
||||
def mock_call_sync_side_effect(func):
|
||||
return func()
|
||||
|
||||
mock_call_sync.side_effect = mock_call_sync_side_effect
|
||||
|
||||
# Mock a_session_maker
|
||||
with patch('server.routes.feedback.a_session_maker', mock_a_session_maker):
|
||||
# Call the function
|
||||
result = await submit_conversation_feedback(feedback_data)
|
||||
|
||||
@@ -74,7 +78,6 @@ async def test_invalid_rating():
|
||||
"""Test submitting feedback with an invalid rating."""
|
||||
# Create a mock database session
|
||||
mock_session = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
# Since Pydantic validation happens before our function is called,
|
||||
# we need to patch the validation to test our function's validation
|
||||
@@ -92,13 +95,14 @@ async def test_invalid_rating():
|
||||
# Mock the validation to return our object
|
||||
mock_validate.return_value = feedback_data
|
||||
|
||||
# Create async context manager for a_session_maker
|
||||
@asynccontextmanager
|
||||
async def mock_a_session_maker():
|
||||
yield mock_session
|
||||
# Mock session_maker and call_sync_from_async
|
||||
with patch('server.routes.feedback.session_maker') as mock_session_maker, patch(
|
||||
'server.routes.feedback.call_sync_from_async'
|
||||
) as mock_call_sync:
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
mock_call_sync.return_value = None
|
||||
|
||||
# Mock a_session_maker
|
||||
with patch('server.routes.feedback.a_session_maker', mock_a_session_maker):
|
||||
# Call the function and expect an exception
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
await submit_conversation_feedback(feedback_data)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Tests for the GitlabCallbackProcessor.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.gitlab.gitlab_view import GitlabIssueComment
|
||||
@@ -111,15 +111,20 @@ class TestGitlabCallbackProcessor:
|
||||
@patch(
|
||||
'server.conversation_callback_processor.gitlab_callback_processor.conversation_manager'
|
||||
)
|
||||
@patch(
|
||||
'server.conversation_callback_processor.gitlab_callback_processor.session_maker'
|
||||
)
|
||||
async def test_call_with_send_summary_instruction(
|
||||
self,
|
||||
mock_session_maker,
|
||||
mock_conversation_manager,
|
||||
mock_get_summary_instruction,
|
||||
async_session_maker,
|
||||
gitlab_callback_processor,
|
||||
):
|
||||
"""Test the __call__ method when send_summary_instruction is True."""
|
||||
# Setup mocks
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_conversation_manager.send_event_to_conversation = AsyncMock()
|
||||
mock_get_summary_instruction.return_value = (
|
||||
"I'm a man of few words. Any questions?"
|
||||
@@ -137,17 +142,15 @@ class TestGitlabCallbackProcessor:
|
||||
)
|
||||
|
||||
# Call the processor
|
||||
with patch(
|
||||
'server.conversation_callback_processor.gitlab_callback_processor.a_session_maker',
|
||||
async_session_maker,
|
||||
):
|
||||
await gitlab_callback_processor(callback, observation)
|
||||
await gitlab_callback_processor(callback, observation)
|
||||
|
||||
# Verify that send_event_to_conversation was called
|
||||
mock_conversation_manager.send_event_to_conversation.assert_called_once()
|
||||
|
||||
# Verify that the processor state was updated
|
||||
assert gitlab_callback_processor.send_summary_instruction is False
|
||||
mock_session.merge.assert_called_once_with(callback)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
@@ -159,16 +162,21 @@ class TestGitlabCallbackProcessor:
|
||||
@patch(
|
||||
'server.conversation_callback_processor.gitlab_callback_processor.asyncio.create_task'
|
||||
)
|
||||
@patch(
|
||||
'server.conversation_callback_processor.gitlab_callback_processor.session_maker'
|
||||
)
|
||||
async def test_call_with_extract_summary(
|
||||
self,
|
||||
mock_session_maker,
|
||||
mock_create_task,
|
||||
mock_extract_summary,
|
||||
mock_conversation_manager,
|
||||
async_session_maker,
|
||||
gitlab_callback_processor,
|
||||
):
|
||||
"""Test the __call__ method when send_summary_instruction is False."""
|
||||
# Setup mocks
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_extract_summary.return_value = 'Test summary'
|
||||
# Ensure we don't leak an un-awaited coroutine when create_task is mocked
|
||||
mock_create_task.side_effect = lambda coro: (coro.close(), None)[1]
|
||||
@@ -188,22 +196,20 @@ class TestGitlabCallbackProcessor:
|
||||
)
|
||||
|
||||
# Call the processor
|
||||
with patch(
|
||||
'server.conversation_callback_processor.gitlab_callback_processor.a_session_maker',
|
||||
async_session_maker,
|
||||
):
|
||||
await gitlab_callback_processor(callback, observation)
|
||||
await gitlab_callback_processor(callback, observation)
|
||||
|
||||
# Verify that extract_summary_from_conversation_manager was called
|
||||
mock_extract_summary.assert_called_once_with(
|
||||
mock_conversation_manager, 'conv123'
|
||||
)
|
||||
|
||||
# Verify that create_task was called at least once to send the message
|
||||
assert mock_create_task.call_count >= 1
|
||||
# Verify that create_task was called to send the message
|
||||
mock_create_task.assert_called_once()
|
||||
|
||||
# Verify that the callback status was updated
|
||||
assert callback.status == CallbackStatus.COMPLETED
|
||||
mock_session.merge.assert_called_once_with(callback)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_with_non_terminal_state(self, gitlab_callback_processor):
|
||||
|
||||
@@ -1,54 +1,56 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from server.auth.token_manager import TokenManager
|
||||
from storage.offline_token_store import OfflineTokenStore
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
return None # Not used in tests
|
||||
return MagicMock(spec=OpenHandsConfig)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token_store(session_maker, mock_config):
|
||||
return OfflineTokenStore('test_user_id', session_maker, mock_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token_manager():
|
||||
with patch('server.config.get_config') as mock_get_config:
|
||||
mock_config = mock_get_config.return_value
|
||||
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
|
||||
return TokenManager(external=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_token_new_record(async_session_maker, mock_config):
|
||||
# Setup - inject the test session maker into the store module
|
||||
import storage.offline_token_store as store_module
|
||||
|
||||
store_module.a_session_maker = async_session_maker
|
||||
|
||||
token_store = OfflineTokenStore('test_user_id', mock_config)
|
||||
async def test_store_token_new_record(token_store, session_maker):
|
||||
# Setup
|
||||
test_token = 'test_offline_token'
|
||||
|
||||
# Execute
|
||||
await token_store.store_token(test_token)
|
||||
|
||||
# Verify - use a new session to query
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StoredOfflineToken).where(
|
||||
StoredOfflineToken.user_id == 'test_user_id'
|
||||
)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
assert record is not None
|
||||
assert record.user_id == 'test_user_id'
|
||||
assert record.offline_token == test_token
|
||||
# Verify
|
||||
with session_maker() as session:
|
||||
query = session.query(StoredOfflineToken)
|
||||
assert query.count() == 1
|
||||
added_record = query.first()
|
||||
assert added_record.user_id == 'test_user_id'
|
||||
assert added_record.offline_token == test_token
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_token_existing_record(async_session_maker, mock_config):
|
||||
# Setup - inject the test session maker into the store module
|
||||
import storage.offline_token_store as store_module
|
||||
|
||||
store_module.a_session_maker = async_session_maker
|
||||
|
||||
token_store = OfflineTokenStore('test_user_id', mock_config)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
async def test_store_token_existing_record(token_store, session_maker):
|
||||
# Setup
|
||||
with session_maker() as session:
|
||||
session.add(
|
||||
StoredOfflineToken(user_id='test_user_id', offline_token='old_token')
|
||||
)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
test_token = 'new_offline_token'
|
||||
|
||||
@@ -56,35 +58,24 @@ async def test_store_token_existing_record(async_session_maker, mock_config):
|
||||
await token_store.store_token(test_token)
|
||||
|
||||
# Verify
|
||||
async with async_session_maker() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await session.execute(
|
||||
select(StoredOfflineToken).where(
|
||||
StoredOfflineToken.user_id == 'test_user_id'
|
||||
)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
assert record is not None
|
||||
assert record.offline_token == test_token
|
||||
with session_maker() as session:
|
||||
query = session.query(StoredOfflineToken)
|
||||
assert query.count() == 1
|
||||
added_record = query.first()
|
||||
assert added_record.user_id == 'test_user_id'
|
||||
assert added_record.offline_token == test_token
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_token_existing(async_session_maker, mock_config):
|
||||
# Setup - inject the test session maker into the store module
|
||||
import storage.offline_token_store as store_module
|
||||
|
||||
store_module.a_session_maker = async_session_maker
|
||||
|
||||
token_store = OfflineTokenStore('test_user_id', mock_config)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
async def test_load_token_existing(token_store, session_maker):
|
||||
# Setup
|
||||
with session_maker() as session:
|
||||
session.add(
|
||||
StoredOfflineToken(
|
||||
user_id='test_user_id', offline_token='test_offline_token'
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
# Execute
|
||||
result = await token_store.load_token()
|
||||
@@ -94,14 +85,7 @@ async def test_load_token_existing(async_session_maker, mock_config):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_token_not_found(async_session_maker, mock_config):
|
||||
# Setup - inject the test session maker into the store module
|
||||
import storage.offline_token_store as store_module
|
||||
|
||||
store_module.a_session_maker = async_session_maker
|
||||
|
||||
token_store = OfflineTokenStore('nonexistent_user', mock_config)
|
||||
|
||||
async def test_load_token_not_found(token_store):
|
||||
# Execute
|
||||
result = await token_store.load_token()
|
||||
|
||||
@@ -120,3 +104,10 @@ async def test_get_instance(mock_config):
|
||||
# Verify
|
||||
assert isinstance(result, OfflineTokenStore)
|
||||
assert result.user_id == test_user_id
|
||||
assert result.config == mock_config
|
||||
|
||||
|
||||
def test_load_store_org_token(token_manager, session_maker):
|
||||
with patch('server.auth.token_manager.session_maker', session_maker):
|
||||
token_manager.store_org_token('some-org-id', 'some-token')
|
||||
assert token_manager.load_org_token('some-org-id') == 'some-token'
|
||||
|
||||
@@ -4,12 +4,17 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
|
||||
# Mock the database module before importing OrgMemberStore
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -9,18 +9,23 @@ import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from server.routes.org_models import (
|
||||
LiteLLMIntegrationError,
|
||||
OrgAuthorizationError,
|
||||
OrgDatabaseError,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
)
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_service import OrgService
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
|
||||
# Mock the database module before importing OrgService
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from server.routes.org_models import (
|
||||
LiteLLMIntegrationError,
|
||||
OrgAuthorizationError,
|
||||
OrgDatabaseError,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
)
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_service import OrgService
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -5,12 +5,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from storage.org import Org
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_store import OrgStore
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
|
||||
# Mock the database module before importing OrgStore
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from storage.org import Org
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_store import OrgStore
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from integrations.github.github_view import get_user_proactive_conversation_setting
|
||||
from storage.org import Org
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from integrations.github.github_view import get_user_proactive_conversation_setting
|
||||
from storage.org import Org
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@@ -1,147 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from storage.repository_store import RepositoryStore
|
||||
from storage.stored_repository import StoredRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repository_store():
|
||||
return RepositoryStore(config=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_projects_empty_list(repository_store, async_session_maker):
|
||||
"""Test storing empty list of repositories."""
|
||||
with patch(
|
||||
'storage.repository_store.RepositoryStore.store_projects'
|
||||
) as mock_method:
|
||||
# Should handle empty list gracefully
|
||||
mock_method.return_value = None
|
||||
# Test that we handle empty repositories
|
||||
result = await repository_store.store_projects([])
|
||||
# The method should return early for empty list
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_projects_new_repositories(repository_store, async_session_maker):
|
||||
"""Test storing new repositories in the database."""
|
||||
# Setup - create repositories
|
||||
repo1 = StoredRepository(
|
||||
repo_name='owner/repo1',
|
||||
repo_id='github##123',
|
||||
is_public=False,
|
||||
)
|
||||
repo2 = StoredRepository(
|
||||
repo_name='owner/repo2',
|
||||
repo_id='github##456',
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
with patch('storage.repository_store.a_session_maker', async_session_maker):
|
||||
await repository_store.store_projects([repo1, repo2])
|
||||
|
||||
# Verify the repositories were stored
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StoredRepository).filter(
|
||||
StoredRepository.repo_id.in_(['github##123', 'github##456'])
|
||||
)
|
||||
)
|
||||
repos = result.scalars().all()
|
||||
assert len(repos) == 2
|
||||
repo_ids = {r.repo_id for r in repos}
|
||||
assert 'github##123' in repo_ids
|
||||
assert 'github##456' in repo_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_projects_update_existing(repository_store, async_session_maker):
|
||||
"""Test updating existing repositories in the database."""
|
||||
# Setup - create existing repository
|
||||
existing_repo = StoredRepository(
|
||||
repo_name='owner/repo1',
|
||||
repo_id='github##123',
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
session.add(existing_repo)
|
||||
await session.commit()
|
||||
|
||||
# Execute - update the repository with new values
|
||||
updated_repo = StoredRepository(
|
||||
repo_name='owner/repo1-updated',
|
||||
repo_id='github##123',
|
||||
is_public=False, # Changed from True
|
||||
)
|
||||
|
||||
with patch('storage.repository_store.a_session_maker', async_session_maker):
|
||||
await repository_store.store_projects([updated_repo])
|
||||
|
||||
# Verify the repository was updated
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StoredRepository).filter(StoredRepository.repo_id == 'github##123')
|
||||
)
|
||||
repo = result.scalars().first()
|
||||
assert repo is not None
|
||||
assert repo.repo_name == 'owner/repo1-updated'
|
||||
assert repo.is_public is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_projects_mixed_new_and_existing(
|
||||
repository_store, async_session_maker
|
||||
):
|
||||
"""Test storing a mix of new and existing repositories."""
|
||||
# Setup - create one existing repository
|
||||
existing_repo = StoredRepository(
|
||||
repo_name='owner/existing-repo',
|
||||
repo_id='github##123',
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
session.add(existing_repo)
|
||||
await session.commit()
|
||||
|
||||
# Execute - store a mix of new and existing
|
||||
repos_to_store = [
|
||||
StoredRepository(
|
||||
repo_name='owner/existing-repo',
|
||||
repo_id='github##123',
|
||||
is_public=False, # Will update
|
||||
),
|
||||
StoredRepository(
|
||||
repo_name='owner/new-repo',
|
||||
repo_id='github##456',
|
||||
is_public=True,
|
||||
),
|
||||
]
|
||||
|
||||
with patch('storage.repository_store.a_session_maker', async_session_maker):
|
||||
await repository_store.store_projects(repos_to_store)
|
||||
|
||||
# Verify results
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StoredRepository).filter(
|
||||
StoredRepository.repo_id.in_(['github##123', 'github##456'])
|
||||
)
|
||||
)
|
||||
repos = result.scalars().all()
|
||||
assert len(repos) == 2
|
||||
|
||||
# Check the updated existing repo
|
||||
existing = next(r for r in repos if r.repo_id == 'github##123')
|
||||
assert existing.repo_name == 'owner/existing-repo'
|
||||
assert existing.is_public is False
|
||||
|
||||
# Check the new repo
|
||||
new = next(r for r in repos if r.repo_id == 'github##456')
|
||||
assert new.repo_name == 'owner/new-repo'
|
||||
assert new.is_public is True
|
||||
@@ -1,14 +1,16 @@
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
from storage.user import User
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_call_sync_from_async():
|
||||
@@ -164,53 +166,3 @@ async def test_exists(session_maker):
|
||||
assert not await store.exists('exists-test')
|
||||
await store.save_metadata(metadata)
|
||||
assert await store.exists('exists-test')
|
||||
|
||||
|
||||
class TestGetInstance:
|
||||
"""Tests for SaasConversationStore.get_instance method.
|
||||
|
||||
The get_instance method uses async UserStore.get_user_by_id_async because
|
||||
callers now use asyncio.run_coroutine_threadsafe() to dispatch to the main
|
||||
event loop where asyncpg connections work properly.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instance_uses_async_get_user_by_id(self):
|
||||
"""Verify get_instance calls the async get_user_by_id_async for proper event loop handling."""
|
||||
# Arrange
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.current_org_id = UUID(user_id)
|
||||
mock_config = MagicMock(spec=OpenHandsConfig)
|
||||
|
||||
with patch(
|
||||
'storage.saas_conversation_store.UserStore.get_user_by_id_async',
|
||||
AsyncMock(return_value=mock_user),
|
||||
) as mock_async_get_user, patch(
|
||||
'storage.saas_conversation_store.session_maker'
|
||||
):
|
||||
# Act
|
||||
store = await SaasConversationStore.get_instance(mock_config, user_id)
|
||||
|
||||
# Assert
|
||||
mock_async_get_user.assert_called_once_with(user_id)
|
||||
assert store.user_id == user_id
|
||||
assert store.org_id == mock_user.current_org_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instance_handles_none_user(self):
|
||||
"""Verify get_instance handles case when user is not found."""
|
||||
# Arrange
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
mock_config = MagicMock(spec=OpenHandsConfig)
|
||||
|
||||
with patch(
|
||||
'storage.saas_conversation_store.UserStore.get_user_by_id_async',
|
||||
AsyncMock(return_value=None),
|
||||
), patch('storage.saas_conversation_store.session_maker'):
|
||||
# Act
|
||||
store = await SaasConversationStore.get_instance(mock_config, user_id)
|
||||
|
||||
# Assert
|
||||
assert store.user_id == user_id
|
||||
assert store.org_id is None
|
||||
|
||||
@@ -29,16 +29,8 @@ def mock_user():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def secrets_store(async_session_maker, mock_config):
|
||||
# Inject the test session maker into the store module
|
||||
import storage.saas_secrets_store as store_module
|
||||
|
||||
store_module.a_session_maker = async_session_maker
|
||||
|
||||
store = SaasSecretsStore('user-id', mock_config)
|
||||
# Also add it as an attribute for tests that need direct access
|
||||
store.a_session_maker = async_session_maker
|
||||
return store
|
||||
def secrets_store(session_maker, mock_config):
|
||||
return SaasSecretsStore('user-id', session_maker, mock_config)
|
||||
|
||||
|
||||
class TestSaasSecretsStore:
|
||||
@@ -115,15 +107,13 @@ class TestSaasSecretsStore:
|
||||
await secrets_store.store(user_secrets)
|
||||
|
||||
# Verify the data is encrypted in the database
|
||||
from sqlalchemy import select
|
||||
|
||||
async with secrets_store.a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StoredCustomSecrets)
|
||||
with secrets_store.session_maker() as session:
|
||||
stored = (
|
||||
session.query(StoredCustomSecrets)
|
||||
.filter(StoredCustomSecrets.keycloak_user_id == 'user-id')
|
||||
.filter(StoredCustomSecrets.org_id == mock_user.current_org_id)
|
||||
.first()
|
||||
)
|
||||
stored = result.scalars().first()
|
||||
|
||||
# The sensitive data should be encrypted
|
||||
assert stored.secret_value != 'sensitive_token'
|
||||
|
||||
@@ -8,7 +8,7 @@ from openhands.server.settings import Settings
|
||||
from openhands.storage.data_models.settings import Settings as DataSettings
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.a_session_maker'):
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from server.constants import (
|
||||
LITE_LLM_API_URL,
|
||||
)
|
||||
@@ -26,21 +26,19 @@ def mock_config():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def settings_store(async_session_maker, mock_config):
|
||||
store = SaasSettingsStore('5594c7b6-f959-4b81-92e9-b09c206f5081', mock_config)
|
||||
store.a_session_maker = async_session_maker
|
||||
def settings_store(session_maker, mock_config):
|
||||
store = SaasSettingsStore(
|
||||
'5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker, mock_config
|
||||
)
|
||||
|
||||
# Patch the load method to read from UserSettings table directly (for testing)
|
||||
async def patched_load():
|
||||
async with store.a_session_maker() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await session.execute(
|
||||
select(UserSettings).filter(
|
||||
UserSettings.keycloak_user_id == store.user_id
|
||||
)
|
||||
with store.session_maker() as session:
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == store.user_id)
|
||||
.first()
|
||||
)
|
||||
user_settings = result.scalars().first()
|
||||
if not user_settings:
|
||||
# Return default settings
|
||||
return Settings(
|
||||
@@ -76,31 +74,29 @@ def settings_store(async_session_maker, mock_config):
|
||||
if 'secrets_store' in item_dict:
|
||||
del item_dict['secrets_store']
|
||||
|
||||
# Encrypt the data before storing
|
||||
store._encrypt_kwargs(item_dict)
|
||||
|
||||
# Continue with the original implementation
|
||||
from sqlalchemy import select
|
||||
|
||||
async with store.a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(UserSettings).filter(
|
||||
with store.session_maker() as session:
|
||||
existing = None
|
||||
if item_dict:
|
||||
store._encrypt_kwargs(item_dict)
|
||||
query = session.query(UserSettings).filter(
|
||||
UserSettings.keycloak_user_id == store.user_id
|
||||
)
|
||||
)
|
||||
existing = result.scalars().first()
|
||||
|
||||
# First check if we have an existing entry in the new table
|
||||
existing = query.first()
|
||||
|
||||
if existing:
|
||||
# Update existing entry
|
||||
for key, value in item_dict.items():
|
||||
if key in existing.__class__.__table__.columns:
|
||||
setattr(existing, key, value)
|
||||
await session.merge(existing)
|
||||
session.merge(existing)
|
||||
else:
|
||||
item_dict['keycloak_user_id'] = store.user_id
|
||||
settings = UserSettings(**item_dict)
|
||||
session.add(settings)
|
||||
await session.commit()
|
||||
session.commit()
|
||||
|
||||
# Replace the methods with our patched versions
|
||||
store.store = patched_store
|
||||
@@ -129,26 +125,25 @@ async def test_store_and_load_keycloak_user(settings_store):
|
||||
assert loaded_settings.agent == 'smith'
|
||||
|
||||
# Verify it was stored in user_settings table with keycloak_user_id
|
||||
from sqlalchemy import select
|
||||
|
||||
async with settings_store.a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(UserSettings).filter(
|
||||
with settings_store.session_maker() as session:
|
||||
stored = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == '550e8400-e29b-41d4-a716-446655440000'
|
||||
)
|
||||
.first()
|
||||
)
|
||||
stored = result.scalars().first()
|
||||
assert stored is not None
|
||||
assert stored.agent == 'smith'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_returns_default_when_not_found(settings_store, async_session_maker):
|
||||
async def test_load_returns_default_when_not_found(settings_store, session_maker):
|
||||
file_store = MagicMock()
|
||||
file_store.read.side_effect = FileNotFoundError()
|
||||
|
||||
with (
|
||||
patch('storage.saas_settings_store.a_session_maker', async_session_maker),
|
||||
patch('storage.saas_settings_store.session_maker', session_maker),
|
||||
):
|
||||
loaded_settings = await settings_store.load()
|
||||
assert loaded_settings is not None
|
||||
@@ -169,15 +164,14 @@ async def test_encryption(settings_store):
|
||||
email_verified=True,
|
||||
)
|
||||
await settings_store.store(settings)
|
||||
from sqlalchemy import select
|
||||
|
||||
async with settings_store.a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(UserSettings).filter(
|
||||
with settings_store.session_maker() as session:
|
||||
stored = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
)
|
||||
.first()
|
||||
)
|
||||
stored = result.scalars().first()
|
||||
# The stored key should be encrypted
|
||||
assert stored.llm_api_key != 'secret_key'
|
||||
# But we should be able to decrypt it when loading
|
||||
@@ -188,7 +182,7 @@ async def test_encryption(settings_store):
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_api_key_keeps_valid_key(mock_config):
|
||||
"""When the existing key is valid, it should be kept unchanged."""
|
||||
store = SaasSettingsStore('test-user-id-123', mock_config)
|
||||
store = SaasSettingsStore('test-user-id-123', MagicMock(), mock_config)
|
||||
existing_key = 'sk-existing-key'
|
||||
item = DataSettings(
|
||||
llm_model='openhands/gpt-4', llm_api_key=SecretStr(existing_key)
|
||||
@@ -211,7 +205,7 @@ async def test_ensure_api_key_generates_new_key_when_verification_fails(
|
||||
mock_config,
|
||||
):
|
||||
"""When verification fails, a new key should be generated."""
|
||||
store = SaasSettingsStore('test-user-id-123', mock_config)
|
||||
store = SaasSettingsStore('test-user-id-123', MagicMock(), mock_config)
|
||||
new_key = 'sk-new-key'
|
||||
item = DataSettings(
|
||||
llm_model='openhands/gpt-4', llm_api_key=SecretStr('sk-invalid-key')
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user