mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
26 Commits
remove-unu
...
optimize-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e9a0bbe87 | ||
|
|
c9af5edad9 | ||
|
|
0c7ce4ad48 | ||
|
|
4dab34e7b0 | ||
|
|
f8bbd352a9 | ||
|
|
17347a95f8 | ||
|
|
01ef87aaaa | ||
|
|
8059c18b57 | ||
|
|
c82ee4c7db | ||
|
|
7fdb423f99 | ||
|
|
530065dfa7 | ||
|
|
a4cd2d81a5 | ||
|
|
003b430e96 | ||
|
|
d63565186e | ||
|
|
5f42d03ec5 | ||
|
|
62241e2e00 | ||
|
|
f5197bd76a | ||
|
|
e1408f7b15 | ||
|
|
d6b8d80026 | ||
|
|
1e6a92b454 | ||
|
|
b4a3e5db2f | ||
|
|
f9d553d0bb | ||
|
|
f6f6c1ab25 | ||
|
|
c511a89426 | ||
|
|
1f82ff04d9 | ||
|
|
eec17311c7 |
@@ -193,14 +193,20 @@ class GithubManager(Manager):
|
||||
github_view.installation_id
|
||||
)
|
||||
# Store the installation token
|
||||
self.token_manager.store_org_token(
|
||||
await self.token_manager.store_org_token(
|
||||
github_view.installation_id, installation_token
|
||||
)
|
||||
# Add eyes reaction to acknowledge we've read the request
|
||||
self._add_reaction(github_view, 'eyes', installation_token)
|
||||
await self.start_job(github_view)
|
||||
|
||||
async def send_message(self, message: Message, github_view: ResolverViewInterface):
|
||||
async def send_message(self, message: str, github_view: ResolverViewInterface):
|
||||
"""Send a message to GitHub.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
github_view: The GitHub view object containing issue/PR/comment info
|
||||
"""
|
||||
installation_token = self.token_manager.load_org_token(
|
||||
github_view.installation_id
|
||||
)
|
||||
@@ -208,14 +214,12 @@ class GithubManager(Manager):
|
||||
logger.warning('Missing installation token')
|
||||
return
|
||||
|
||||
outgoing_message = message.message
|
||||
|
||||
if isinstance(github_view, GithubInlinePRComment):
|
||||
with Github(auth=Auth.Token(installation_token)) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
pr = repo.get_pull(github_view.issue_number)
|
||||
pr.create_review_comment_reply(
|
||||
comment_id=github_view.comment_id, body=outgoing_message
|
||||
comment_id=github_view.comment_id, body=message
|
||||
)
|
||||
|
||||
elif (
|
||||
@@ -226,7 +230,7 @@ class GithubManager(Manager):
|
||||
with Github(auth=Auth.Token(installation_token)) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
issue = repo.get_issue(number=github_view.issue_number)
|
||||
issue.create_comment(outgoing_message)
|
||||
issue.create_comment(message)
|
||||
|
||||
else:
|
||||
logger.warning('Unsupported location')
|
||||
@@ -245,7 +249,7 @@ class GithubManager(Manager):
|
||||
)
|
||||
|
||||
try:
|
||||
msg_info = None
|
||||
msg_info: str = ''
|
||||
|
||||
try:
|
||||
user_info = github_view.user_info
|
||||
@@ -361,15 +365,13 @@ class GithubManager(Manager):
|
||||
|
||||
msg_info = get_session_expired_message(user_info.username)
|
||||
|
||||
msg = self.create_outgoing_message(msg_info)
|
||||
await self.send_message(msg, github_view)
|
||||
await self.send_message(msg_info, github_view)
|
||||
|
||||
except Exception:
|
||||
logger.exception('[Github]: Error starting job')
|
||||
msg = self.create_outgoing_message(
|
||||
msg='Uh oh! There was an unexpected error starting the job :('
|
||||
await self.send_message(
|
||||
'Uh oh! There was an unexpected error starting the job :(', github_view
|
||||
)
|
||||
await self.send_message(msg, github_view)
|
||||
|
||||
try:
|
||||
await self.data_collector.save_data(github_view)
|
||||
|
||||
@@ -14,7 +14,6 @@ from integrations.solvability.models.summary import SolvabilitySummary
|
||||
from integrations.utils import ENABLE_SOLVABILITY_ANALYSIS
|
||||
from pydantic import ValidationError
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
@@ -90,7 +89,6 @@ async def summarize_issue_solvability(
|
||||
# Grab the user's information so we can load their LLM configuration
|
||||
store = SaasSettingsStore(
|
||||
user_id=github_view.user_info.keycloak_user_id,
|
||||
session_maker=session_maker,
|
||||
config=get_config(),
|
||||
)
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ from jinja2 import Environment
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.org_store import OrgStore
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
@@ -153,9 +152,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
secrets_store = SaasSecretsStore(
|
||||
self.user_info.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
secrets_store = SaasSecretsStore(self.user_info.keycloak_user_id, get_config())
|
||||
user_secrets = await secrets_store.load()
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
|
||||
@@ -121,12 +121,11 @@ class GitlabManager(Manager):
|
||||
# Check if the user has write access to the repository
|
||||
return has_write_access
|
||||
|
||||
async def send_message(self, message: Message, gitlab_view: ResolverViewInterface):
|
||||
"""
|
||||
Send a message to GitLab based on the view type.
|
||||
async def send_message(self, message: str, gitlab_view: ResolverViewInterface):
|
||||
"""Send a message to GitLab based on the view type.
|
||||
|
||||
Args:
|
||||
message: The message to send
|
||||
message: The message content to send (plain text string)
|
||||
gitlab_view: The GitLab view object containing issue/PR/comment info
|
||||
"""
|
||||
keycloak_user_id = gitlab_view.user_info.keycloak_user_id
|
||||
@@ -138,8 +137,6 @@ class GitlabManager(Manager):
|
||||
external_auth_id=keycloak_user_id
|
||||
)
|
||||
|
||||
outgoing_message = message.message
|
||||
|
||||
if isinstance(gitlab_view, GitlabInlineMRComment) or isinstance(
|
||||
gitlab_view, GitlabMRComment
|
||||
):
|
||||
@@ -147,7 +144,7 @@ class GitlabManager(Manager):
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
gitlab_view.discussion_id,
|
||||
message.message,
|
||||
message,
|
||||
)
|
||||
|
||||
elif isinstance(gitlab_view, GitlabIssueComment):
|
||||
@@ -155,14 +152,14 @@ class GitlabManager(Manager):
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
gitlab_view.discussion_id,
|
||||
outgoing_message,
|
||||
message,
|
||||
)
|
||||
elif isinstance(gitlab_view, GitlabIssue):
|
||||
await gitlab_service.reply_to_issue(
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
None, # no discussion id, issue is tagged
|
||||
outgoing_message,
|
||||
message,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -262,12 +259,10 @@ class GitlabManager(Manager):
|
||||
msg_info = get_session_expired_message(user_info.username)
|
||||
|
||||
# Send the acknowledgment message
|
||||
msg = self.create_outgoing_message(msg_info)
|
||||
await self.send_message(msg, gitlab_view)
|
||||
await self.send_message(msg_info, gitlab_view)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f'[GitLab] Error starting job: {str(e)}')
|
||||
msg = self.create_outgoing_message(
|
||||
msg='Uh oh! There was an unexpected error starting the job :('
|
||||
await self.send_message(
|
||||
'Uh oh! There was an unexpected error starting the job :(', gitlab_view
|
||||
)
|
||||
await self.send_message(msg, gitlab_view)
|
||||
|
||||
@@ -6,7 +6,6 @@ from integrations.utils import HOST, get_oh_labels, has_exact_mention
|
||||
from jinja2 import Environment
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -78,9 +77,7 @@ class GitlabIssue(ResolverViewInterface):
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
secrets_store = SaasSecretsStore(
|
||||
self.user_info.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
secrets_store = SaasSecretsStore(self.user_info.keycloak_user_id, get_config())
|
||||
user_secrets = await secrets_store.load()
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
@@ -449,3 +446,5 @@ class GitlabFactory:
|
||||
previous_comments=[],
|
||||
is_mr=True,
|
||||
)
|
||||
|
||||
raise ValueError(f'Unhandled GitLab webhook event: {message}')
|
||||
|
||||
@@ -341,17 +341,25 @@ class JiraManager(Manager):
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Message,
|
||||
message: str,
|
||||
issue_key: str,
|
||||
jira_cloud_id: str,
|
||||
svc_acc_email: str,
|
||||
svc_acc_api_key: str,
|
||||
):
|
||||
"""Send a comment to a Jira issue."""
|
||||
"""Send a comment to a Jira issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
issue_key: The Jira issue key (e.g., 'PROJ-123')
|
||||
jira_cloud_id: The Jira Cloud ID
|
||||
svc_acc_email: Service account email for authentication
|
||||
svc_acc_api_key: Service account API key for authentication
|
||||
"""
|
||||
url = (
|
||||
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
|
||||
)
|
||||
data = {'body': message.message}
|
||||
data = {'body': message}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(
|
||||
url, auth=(svc_acc_email, svc_acc_api_key), json=data
|
||||
@@ -366,7 +374,7 @@ class JiraManager(Manager):
|
||||
view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg),
|
||||
msg,
|
||||
issue_key=view.payload.issue_key,
|
||||
jira_cloud_id=view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=view.jira_workspace.svc_acc_email,
|
||||
@@ -388,7 +396,7 @@ class JiraManager(Manager):
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg),
|
||||
error_msg,
|
||||
issue_key=payload.issue_key,
|
||||
jira_cloud_id=workspace.jira_cloud_id,
|
||||
svc_acc_email=workspace.svc_acc_email,
|
||||
|
||||
@@ -212,8 +212,6 @@ class JiraPayloadParser:
|
||||
missing.append('issue.id')
|
||||
if not issue_key:
|
||||
missing.append('issue.key')
|
||||
if not user_email:
|
||||
missing.append('user.emailAddress')
|
||||
if not display_name:
|
||||
missing.append('user.displayName')
|
||||
if not account_id:
|
||||
|
||||
@@ -418,7 +418,7 @@ class JiraDcManager(Manager):
|
||||
jira_dc_view.jira_dc_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
msg_info,
|
||||
issue_key=jira_dc_view.job_context.issue_key,
|
||||
base_api_url=jira_dc_view.job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
@@ -456,12 +456,19 @@ class JiraDcManager(Manager):
|
||||
return title, description
|
||||
|
||||
async def send_message(
|
||||
self, message: Message, issue_key: str, base_api_url: str, svc_acc_api_key: str
|
||||
self, message: str, issue_key: str, base_api_url: str, svc_acc_api_key: str
|
||||
):
|
||||
"""Send message/comment to Jira DC issue."""
|
||||
"""Send message/comment to Jira DC issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
issue_key: The Jira issue key (e.g., 'PROJ-123')
|
||||
base_api_url: The base API URL for the Jira DC instance
|
||||
svc_acc_api_key: Service account API key for authentication
|
||||
"""
|
||||
url = f'{base_api_url}/rest/api/2/issue/{issue_key}/comment'
|
||||
headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
|
||||
data = {'body': message.message}
|
||||
data = {'body': message}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
@@ -481,7 +488,7 @@ class JiraDcManager(Manager):
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg),
|
||||
error_msg,
|
||||
issue_key=job_context.issue_key,
|
||||
base_api_url=job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
@@ -502,7 +509,7 @@ class JiraDcManager(Manager):
|
||||
)
|
||||
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
comment_msg,
|
||||
issue_key=jira_dc_view.job_context.issue_key,
|
||||
base_api_url=jira_dc_view.job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
|
||||
@@ -19,7 +19,7 @@ class JiraDcViewInterface(ABC):
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ class JiraDcNewConversationView(JiraDcViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('jira_dc_instructions.j2')
|
||||
@@ -61,7 +61,7 @@ class JiraDcNewConversationView(JiraDcViewInterface):
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
instructions, user_msg = await self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
@@ -113,7 +113,7 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_dc_existing_conversation.j2')
|
||||
@@ -167,7 +167,7 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
_, user_msg = await self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
|
||||
@@ -408,7 +408,7 @@ class LinearManager(Manager):
|
||||
linear_view.linear_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
msg_info,
|
||||
linear_view.job_context.issue_id,
|
||||
api_key,
|
||||
)
|
||||
@@ -473,8 +473,14 @@ class LinearManager(Manager):
|
||||
|
||||
return title, description
|
||||
|
||||
async def send_message(self, message: Message, issue_id: str, api_key: str):
|
||||
"""Send message/comment to Linear issue."""
|
||||
async def send_message(self, message: str, issue_id: str, api_key: str):
|
||||
"""Send message/comment to Linear issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
issue_id: The Linear issue ID to comment on
|
||||
api_key: The Linear API key for authentication
|
||||
"""
|
||||
query = """
|
||||
mutation CommentCreate($input: CommentCreateInput!) {
|
||||
commentCreate(input: $input) {
|
||||
@@ -485,7 +491,7 @@ class LinearManager(Manager):
|
||||
}
|
||||
}
|
||||
"""
|
||||
variables = {'input': {'issueId': issue_id, 'body': message.message}}
|
||||
variables = {'input': {'issueId': issue_id, 'body': message}}
|
||||
return await self._query_api(query, variables, api_key)
|
||||
|
||||
async def _send_error_comment(
|
||||
@@ -498,9 +504,7 @@ class LinearManager(Manager):
|
||||
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg), issue_id, api_key
|
||||
)
|
||||
await self.send_message(error_msg, issue_id, api_key)
|
||||
except Exception as e:
|
||||
logger.error(f'[Linear] Failed to send error comment: {str(e)}')
|
||||
|
||||
@@ -517,7 +521,7 @@ class LinearManager(Manager):
|
||||
)
|
||||
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
comment_msg,
|
||||
linear_view.job_context.issue_id,
|
||||
api_key,
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ class LinearViewInterface(ABC):
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class LinearNewConversationView(LinearViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('linear_instructions.j2')
|
||||
@@ -58,7 +58,7 @@ class LinearNewConversationView(LinearViewInterface):
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
instructions, user_msg = await self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
@@ -110,7 +110,7 @@ class LinearExistingConversationView(LinearViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('linear_existing_conversation.j2')
|
||||
@@ -164,7 +164,7 @@ class LinearExistingConversationView(LinearViewInterface):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
_, user_msg = await self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from integrations.models import Message, SourceType
|
||||
|
||||
@@ -12,14 +13,15 @@ class Manager(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def send_message(self, message: Message):
|
||||
"Send message to integration from Openhands server"
|
||||
def send_message(self, message: str, *args: Any, **kwargs: Any):
|
||||
"""Send message to integration from OpenHands server.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self):
|
||||
"Kick off a job with openhands agent"
|
||||
raise NotImplementedError
|
||||
|
||||
def create_outgoing_message(self, msg: str | dict, ephemeral: bool = False):
|
||||
return Message(source=SourceType.OPENHANDS, message=msg, ephemeral=ephemeral)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -16,8 +17,16 @@ class SourceType(str, Enum):
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Message model for incoming webhook payloads from integrations.
|
||||
|
||||
Note: This model is intended for INCOMING messages only.
|
||||
For outgoing messages (e.g., sending comments to GitHub/GitLab),
|
||||
pass strings directly to the send_message methods instead of
|
||||
wrapping them in a Message object.
|
||||
"""
|
||||
|
||||
source: SourceType
|
||||
message: str | dict
|
||||
message: dict[str, Any]
|
||||
ephemeral: bool = False
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from integrations.manager import Manager
|
||||
@@ -22,7 +23,8 @@ from server.constants import SLACK_CLIENT_ID
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
from slack_sdk.oauth import AuthorizeUrlGenerator
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.slack_user import SlackUser
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -63,12 +65,11 @@ class SlackManager(Manager):
|
||||
) -> tuple[SlackUser | None, UserAuth | None]:
|
||||
# We get the user and correlate them back to a user in OpenHands - if we can
|
||||
slack_user = None
|
||||
with session_maker() as session:
|
||||
slack_user = (
|
||||
session.query(SlackUser)
|
||||
.filter(SlackUser.slack_user_id == slack_user_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(SlackUser).where(SlackUser.slack_user_id == slack_user_id)
|
||||
)
|
||||
slack_user = result.scalar_one_or_none()
|
||||
|
||||
# slack_view.slack_to_openhands_user = slack_user # attach user auth info to view
|
||||
|
||||
@@ -202,9 +203,7 @@ class SlackManager(Manager):
|
||||
msg = self.login_link.format(link)
|
||||
|
||||
logger.info('slack_not_yet_authenticated')
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg, ephemeral=True), slack_view
|
||||
)
|
||||
await self.send_message(msg, slack_view, ephemeral=True)
|
||||
return
|
||||
|
||||
if not await self.is_job_requested(message, slack_view):
|
||||
@@ -212,27 +211,40 @@ class SlackManager(Manager):
|
||||
|
||||
await self.start_job(slack_view)
|
||||
|
||||
async def send_message(self, message: Message, slack_view: SlackViewInterface):
|
||||
async def send_message(
|
||||
self,
|
||||
message: str | dict[str, Any],
|
||||
slack_view: SlackViewInterface,
|
||||
ephemeral: bool = False,
|
||||
):
|
||||
"""Send a message to Slack.
|
||||
|
||||
Args:
|
||||
message: The message content. Can be a string (for simple text) or
|
||||
a dict with 'text' and 'blocks' keys (for structured messages).
|
||||
slack_view: The Slack view object containing channel/thread info.
|
||||
ephemeral: If True, send as an ephemeral message visible only to the user.
|
||||
"""
|
||||
client = AsyncWebClient(token=slack_view.bot_access_token)
|
||||
if message.ephemeral and isinstance(message.message, str):
|
||||
if ephemeral and isinstance(message, str):
|
||||
await client.chat_postEphemeral(
|
||||
channel=slack_view.channel_id,
|
||||
markdown_text=message.message,
|
||||
markdown_text=message,
|
||||
user=slack_view.slack_user_id,
|
||||
thread_ts=slack_view.thread_ts,
|
||||
)
|
||||
elif message.ephemeral and isinstance(message.message, dict):
|
||||
elif ephemeral and isinstance(message, dict):
|
||||
await client.chat_postEphemeral(
|
||||
channel=slack_view.channel_id,
|
||||
user=slack_view.slack_user_id,
|
||||
thread_ts=slack_view.thread_ts,
|
||||
text=message.message['text'],
|
||||
blocks=message.message['blocks'],
|
||||
text=message['text'],
|
||||
blocks=message['blocks'],
|
||||
)
|
||||
else:
|
||||
await client.chat_postMessage(
|
||||
channel=slack_view.channel_id,
|
||||
markdown_text=message.message,
|
||||
markdown_text=message,
|
||||
thread_ts=slack_view.message_ts,
|
||||
)
|
||||
|
||||
@@ -279,10 +291,7 @@ class SlackManager(Manager):
|
||||
repos, slack_view.message_ts, slack_view.thread_ts
|
||||
),
|
||||
}
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(repo_selection_msg, ephemeral=True),
|
||||
slack_view,
|
||||
)
|
||||
await self.send_message(repo_selection_msg, slack_view, ephemeral=True)
|
||||
|
||||
return False
|
||||
|
||||
@@ -368,9 +377,10 @@ class SlackManager(Manager):
|
||||
except StartingConvoException as e:
|
||||
msg_info = str(e)
|
||||
|
||||
await self.send_message(self.create_outgoing_message(msg_info), slack_view)
|
||||
await self.send_message(msg_info, slack_view)
|
||||
|
||||
except Exception:
|
||||
logger.exception('[Slack]: Error starting job')
|
||||
msg = 'Uh oh! There was an unexpected error starting the job :('
|
||||
await self.send_message(self.create_outgoing_message(msg), slack_view)
|
||||
await self.send_message(
|
||||
'Uh oh! There was an unexpected error starting the job :(', slack_view
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ class SlackViewInterface(SummaryExtractionTracker, ABC):
|
||||
v1_enabled: bool
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ class SlackUnkownUserView(SlackViewInterface):
|
||||
team_id: str
|
||||
v1_enabled: bool
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment):
|
||||
@@ -118,7 +118,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
return block['user_id']
|
||||
return ''
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
|
||||
@@ -242,7 +242,9 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
self, jinja: Environment, provider_tokens, user_secrets
|
||||
) -> None:
|
||||
"""Create conversation using the legacy V0 system."""
|
||||
user_instructions, conversation_instructions = self._get_instructions(jinja)
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja
|
||||
)
|
||||
|
||||
# Determine git provider from repository
|
||||
git_provider = None
|
||||
@@ -273,7 +275,9 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
|
||||
async def _create_v1_conversation(self, jinja: Environment) -> None:
|
||||
"""Create conversation using the new V1 app conversation system."""
|
||||
user_instructions, conversation_instructions = self._get_instructions(jinja)
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja
|
||||
)
|
||||
|
||||
# Create the initial message request
|
||||
initial_message = SendMessageRequest(
|
||||
@@ -346,7 +350,7 @@ class SlackNewConversationFromRepoFormView(SlackNewConversationView):
|
||||
class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
slack_conversation: SlackConversation
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
client = WebClient(token=self.bot_access_token)
|
||||
result = client.conversations_replies(
|
||||
channel=self.channel_id,
|
||||
@@ -401,7 +405,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
instructions, _ = self._get_instructions(jinja)
|
||||
instructions, _ = await self._get_instructions(jinja)
|
||||
user_msg = MessageAction(content=instructions)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_msg)
|
||||
@@ -469,7 +473,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
agent_server_url = get_agent_server_url_from_sandbox(running_sandbox)
|
||||
|
||||
# 4. Prepare the message content
|
||||
user_msg, _ = self._get_instructions(jinja)
|
||||
user_msg, _ = await self._get_instructions(jinja)
|
||||
|
||||
# 5. Create the message request
|
||||
send_message_request = SendMessageRequest(
|
||||
|
||||
@@ -42,11 +42,11 @@ async def store_repositories_in_db(repos: list[Repository], user_id: str) -> Non
|
||||
try:
|
||||
# Store repositories in the repos table
|
||||
repo_store = RepositoryStore.get_instance(config)
|
||||
repo_store.store_projects(stored_repos)
|
||||
await repo_store.store_projects(stored_repos)
|
||||
|
||||
# Store user-repository mappings in the user-repos table
|
||||
user_repo_store = UserRepositoryMapStore.get_instance(config)
|
||||
user_repo_store.store_user_repo_mappings(user_repos)
|
||||
await user_repo_store.store_user_repo_mappings(user_repos)
|
||||
|
||||
logger.info(f'Saved repos for user {user_id}')
|
||||
except Exception:
|
||||
|
||||
@@ -3,8 +3,8 @@ from uuid import UUID
|
||||
import stripe
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
@@ -15,12 +15,10 @@ stripe.api_key = STRIPE_API_KEY
|
||||
|
||||
|
||||
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
with session_maker() as session:
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.org_id == org_id)
|
||||
.first()
|
||||
)
|
||||
async with a_session_maker() as session:
|
||||
stmt = select(StripeCustomer).where(StripeCustomer.org_id == org_id)
|
||||
result = await session.execute(stmt)
|
||||
stripe_customer = result.scalar_one_or_none()
|
||||
if stripe_customer:
|
||||
return stripe_customer.stripe_customer_id
|
||||
|
||||
@@ -74,7 +72,7 @@ async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
)
|
||||
|
||||
# Save the stripe customer in the local db
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id=user_id,
|
||||
@@ -82,7 +80,7 @@ async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
stripe_customer_id=customer.id,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
'created_customer',
|
||||
@@ -108,26 +106,27 @@ async def has_payment_method_by_user_id(user_id: str) -> bool:
|
||||
return bool(payment_methods.data)
|
||||
|
||||
|
||||
async def migrate_customer(session: Session, user_id: str, org: Org):
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
async def migrate_customer(user_id: str, org: Org):
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StripeCustomer).where(StripeCustomer.keycloak_user_id == user_id)
|
||||
)
|
||||
stripe_customer = result.scalar_one_or_none()
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
@@ -38,7 +38,7 @@ class ResolverViewInterface(SummaryExtractionTracker):
|
||||
is_public_repo: bool
|
||||
raw_payload: dict
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"Instructions passed when conversation is first initialized"
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
17
enterprise/poetry.lock
generated
17
enterprise/poetry.lock
generated
@@ -1591,6 +1591,9 @@ files = [
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7"},
|
||||
{file = "cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d"},
|
||||
]
|
||||
|
||||
@@ -6168,23 +6171,23 @@ opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
|
||||
pathspec = ">=0.12.1"
|
||||
pexpect = "*"
|
||||
pg8000 = ">=1.31.5"
|
||||
pillow = ">=11.3"
|
||||
pillow = ">=12.1.1"
|
||||
playwright = ">=1.55"
|
||||
poetry = ">=2.1.2"
|
||||
prompt-toolkit = ">=3.0.50"
|
||||
protobuf = ">=5,<6"
|
||||
protobuf = ">=5.29.6,<6"
|
||||
psutil = "*"
|
||||
pybase62 = ">=1"
|
||||
pygithub = ">=2.5"
|
||||
pyjwt = ">=2.9"
|
||||
pylatexenc = "*"
|
||||
pypdf = ">=6"
|
||||
pypdf = ">=6.7.2"
|
||||
python-docx = "*"
|
||||
python-dotenv = "*"
|
||||
python-frontmatter = ">=1.1"
|
||||
python-jose = {version = ">=3.3", extras = ["cryptography"]}
|
||||
python-json-logger = ">=3.2.1"
|
||||
python-multipart = "*"
|
||||
python-multipart = ">=0.0.22"
|
||||
python-pptx = "*"
|
||||
python-socketio = "5.14"
|
||||
pythonnet = "*"
|
||||
@@ -6197,7 +6200,7 @@ setuptools = ">=78.1.1"
|
||||
shellingham = ">=1.5.4"
|
||||
sqlalchemy = {version = ">=2.0.40", extras = ["asyncio"]}
|
||||
sse-starlette = ">=3.0.2"
|
||||
starlette = ">=0.48"
|
||||
starlette = ">=0.49.1"
|
||||
tenacity = ">=8.5,<10"
|
||||
termcolor = "*"
|
||||
toml = "*"
|
||||
@@ -11961,7 +11964,7 @@ description = "Python for Window Extensions"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
markers = "platform_system == \"Windows\" or sys_platform == \"win32\""
|
||||
markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
|
||||
files = [
|
||||
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
|
||||
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
|
||||
@@ -14909,4 +14912,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12,<3.14"
|
||||
content-hash = "1cad6029269393af67155e930c72eae2c03da02e4b3a3699823f6168c14a4218"
|
||||
content-hash = "ef037f6d6085d26166d35c56ce266439f8f1a4fea90bc43ccf15cfeaf116cae5"
|
||||
|
||||
@@ -49,7 +49,7 @@ prometheus-client = "^0.24.0"
|
||||
pandas = "^2.2.0"
|
||||
numpy = "^2.2.0"
|
||||
mcp = "^1.10.0"
|
||||
pillow = "^12.1.0"
|
||||
pillow = "^12.1.1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.8.3"
|
||||
|
||||
@@ -48,15 +48,18 @@ from server.routes.orgs import org_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
from server.routes.user_app_settings import user_app_settings_router # noqa: E402
|
||||
from server.routes.verified_models import ( # noqa: E402
|
||||
api_router as verified_models_router,
|
||||
)
|
||||
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||
router as shared_conversation_router,
|
||||
)
|
||||
from server.sharing.shared_event_router import ( # noqa: E402
|
||||
router as shared_event_router,
|
||||
)
|
||||
from server.verified_models.verified_model_router import ( # noqa: E402
|
||||
api_router as verified_models_router,
|
||||
)
|
||||
from server.verified_models.verified_model_router import ( # noqa: E402
|
||||
override_llm_models_dependency,
|
||||
)
|
||||
|
||||
from openhands.server.app import app as base_app # noqa: E402
|
||||
from openhands.server.listen_socket import sio # noqa: E402
|
||||
@@ -113,6 +116,11 @@ base_app.include_router(org_router) # Add routes for organization management
|
||||
base_app.include_router(
|
||||
verified_models_router
|
||||
) # Add routes for verified models management
|
||||
|
||||
# Override the default LLM models implementation with SaaS version
|
||||
# This must happen after all routers are included
|
||||
override_llm_models_dependency(base_app)
|
||||
|
||||
base_app.include_router(invitation_router) # Add routes for org invitation management
|
||||
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
|
||||
add_github_proxy_routes(base_app)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from storage.blocked_email_domain_store import BlockedEmailDomainStore
|
||||
from storage.database import session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -23,7 +22,7 @@ class DomainBlocker:
|
||||
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
|
||||
return None
|
||||
|
||||
def is_domain_blocked(self, email: str) -> bool:
|
||||
async def is_domain_blocked(self, email: str) -> bool:
|
||||
"""Check if email domain is blocked by querying the database directly via SQL.
|
||||
|
||||
Supports blocking:
|
||||
@@ -45,7 +44,7 @@ class DomainBlocker:
|
||||
|
||||
try:
|
||||
# Query database directly via SQL to check if domain is blocked
|
||||
is_blocked = self.store.is_domain_blocked(domain)
|
||||
is_blocked = await self.store.is_domain_blocked(domain)
|
||||
|
||||
if is_blocked:
|
||||
logger.warning(f'Email domain {domain} is blocked for email: {email}')
|
||||
@@ -63,5 +62,5 @@ class DomainBlocker:
|
||||
|
||||
|
||||
# Initialize store and domain blocker
|
||||
_store = BlockedEmailDomainStore(session_maker=session_maker)
|
||||
_store = BlockedEmailDomainStore()
|
||||
domain_blocker = DomainBlocker(store=_store)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from integrations.github.github_service import SaaSGitHubService
|
||||
from pydantic import SecretStr
|
||||
from server.auth.auth_utils import user_verifier
|
||||
|
||||
from enterprise.server.auth.auth_utils import user_verifier
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_types import GitHubUser
|
||||
|
||||
|
||||
@@ -18,9 +18,10 @@ from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
from server.rate_limit import RateLimiter, create_redis_rate_limiter
|
||||
from sqlalchemy import delete, select
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
@@ -124,7 +125,7 @@ class SaasUserAuth(UserAuth):
|
||||
if secrets_store:
|
||||
return secrets_store
|
||||
user_id = await self.get_user_id()
|
||||
secrets_store = SaasSecretsStore(user_id, session_maker, get_config())
|
||||
secrets_store = SaasSecretsStore(user_id, get_config())
|
||||
self.secrets_store = secrets_store
|
||||
return secrets_store
|
||||
|
||||
@@ -161,12 +162,13 @@ class SaasUserAuth(UserAuth):
|
||||
|
||||
try:
|
||||
# TODO: I think we can do this in a single request if we refactor
|
||||
with session_maker() as session:
|
||||
tokens = (
|
||||
session.query(AuthTokens)
|
||||
.where(AuthTokens.keycloak_user_id == self.user_id)
|
||||
.all()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(AuthTokens).where(
|
||||
AuthTokens.keycloak_user_id == self.user_id
|
||||
)
|
||||
)
|
||||
tokens = result.scalars().all()
|
||||
|
||||
for token in tokens:
|
||||
idp_type = ProviderType(token.identity_provider)
|
||||
@@ -192,11 +194,11 @@ class SaasUserAuth(UserAuth):
|
||||
'idp_type': token.identity_provider,
|
||||
},
|
||||
)
|
||||
with session_maker() as session:
|
||||
session.query(AuthTokens).filter(
|
||||
AuthTokens.id == token.id
|
||||
).delete()
|
||||
session.commit()
|
||||
async with a_session_maker() as session:
|
||||
await session.execute(
|
||||
delete(AuthTokens).where(AuthTokens.id == token.id)
|
||||
)
|
||||
await session.commit()
|
||||
raise
|
||||
|
||||
self.provider_tokens = MappingProxyType(provider_tokens)
|
||||
@@ -210,7 +212,7 @@ class SaasUserAuth(UserAuth):
|
||||
if settings_store:
|
||||
return settings_store
|
||||
user_id = await self.get_user_id()
|
||||
settings_store = SaasSettingsStore(user_id, session_maker, get_config())
|
||||
settings_store = SaasSettingsStore(user_id, get_config())
|
||||
self.settings_store = settings_store
|
||||
return settings_store
|
||||
|
||||
@@ -278,7 +280,7 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
|
||||
return None
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
user_id = api_key_store.validate_api_key(api_key)
|
||||
user_id = await api_key_store.validate_api_key(api_key)
|
||||
if not user_id:
|
||||
return None
|
||||
offline_token = await token_manager.load_offline_token(user_id)
|
||||
@@ -327,7 +329,7 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
|
||||
email_verified = access_token_payload['email_verified']
|
||||
|
||||
# Check if email domain is blocked
|
||||
if email and domain_blocker.is_domain_blocked(email):
|
||||
if email and await domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for existing user with email: {email}'
|
||||
)
|
||||
|
||||
@@ -38,9 +38,9 @@ from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
from sqlalchemy import String as SQLString
|
||||
from sqlalchemy import type_coerce
|
||||
from sqlalchemy import select, type_coerce
|
||||
from storage.auth_token_store import AuthTokenStore
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.offline_token_store import OfflineTokenStore
|
||||
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt
|
||||
@@ -783,25 +783,24 @@ class TokenManager:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def store_org_token(self, installation_id: int, installation_token: str):
|
||||
async def store_org_token(self, installation_id: int, installation_token: str):
|
||||
"""Store a GitHub App installation token.
|
||||
|
||||
Args:
|
||||
installation_id: GitHub installation ID (integer or string)
|
||||
installation_token: The token to store
|
||||
"""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Ensure installation_id is a string
|
||||
str_installation_id = str(installation_id)
|
||||
# Use type_coerce to ensure SQLAlchemy treats the parameter as a string
|
||||
installation = (
|
||||
session.query(GithubAppInstallation)
|
||||
.filter(
|
||||
result = await session.execute(
|
||||
select(GithubAppInstallation).filter(
|
||||
GithubAppInstallation.installation_id
|
||||
== type_coerce(str_installation_id, SQLString)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
installation = result.scalars().first()
|
||||
if installation:
|
||||
installation.encrypted_token = self.encrypt_text(installation_token)
|
||||
else:
|
||||
@@ -811,9 +810,9 @@ class TokenManager:
|
||||
encrypted_token=self.encrypt_text(installation_token),
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
def load_org_token(self, installation_id: int) -> str | None:
|
||||
async def load_org_token(self, installation_id: int) -> str | None:
|
||||
"""Load a GitHub App installation token.
|
||||
|
||||
Args:
|
||||
@@ -822,17 +821,16 @@ class TokenManager:
|
||||
Returns:
|
||||
The decrypted token if found, None otherwise
|
||||
"""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Ensure installation_id is a string and use type_coerce
|
||||
str_installation_id = str(installation_id)
|
||||
installation = (
|
||||
session.query(GithubAppInstallation)
|
||||
.filter(
|
||||
result = await session.execute(
|
||||
select(GithubAppInstallation).filter(
|
||||
GithubAppInstallation.installation_id
|
||||
== type_coerce(str_installation_id, SQLString)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
installation = result.scalars().first()
|
||||
if not installation:
|
||||
return None
|
||||
token = self.decrypt_text(installation.encrypted_token)
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime
|
||||
|
||||
from integrations.github.github_manager import GithubManager
|
||||
from integrations.github.github_view import GithubViewType
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.utils import (
|
||||
extract_summary_from_conversation_manager,
|
||||
get_summary_instruction,
|
||||
@@ -35,16 +34,12 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
send_summary_instruction: bool = True
|
||||
|
||||
async def _send_message_to_github(self, message: str) -> None:
|
||||
"""
|
||||
Send a message to GitHub.
|
||||
"""Send a message to GitHub.
|
||||
|
||||
Args:
|
||||
message: The message content to send to GitHub
|
||||
"""
|
||||
try:
|
||||
# Create a message object for GitHub
|
||||
message_obj = Message(source=SourceType.OPENHANDS, message=message)
|
||||
|
||||
# Get the token manager
|
||||
token_manager = TokenManager()
|
||||
|
||||
@@ -53,8 +48,8 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
|
||||
github_manager = GithubManager(token_manager, GitHubDataCollector())
|
||||
|
||||
# Send the message
|
||||
await github_manager.send_message(message_obj, self.github_view)
|
||||
# Send the message directly as a string
|
||||
await github_manager.send_message(message, self.github_view)
|
||||
|
||||
logger.info(
|
||||
f'[GitHub] Sent summary message to {self.github_view.full_repo_name}#{self.github_view.issue_number}'
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime
|
||||
|
||||
from integrations.gitlab.gitlab_manager import GitlabManager
|
||||
from integrations.gitlab.gitlab_view import GitlabViewType
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.utils import (
|
||||
extract_summary_from_conversation_manager,
|
||||
get_summary_instruction,
|
||||
@@ -14,7 +13,7 @@ from storage.conversation_callback import (
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
@@ -28,8 +27,7 @@ gitlab_manager = GitlabManager(token_manager)
|
||||
|
||||
|
||||
class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
"""
|
||||
Processor for sending conversation summaries to GitLab.
|
||||
"""Processor for sending conversation summaries to GitLab.
|
||||
|
||||
This processor is used to send summaries of conversations to GitLab
|
||||
when agent state changes occur.
|
||||
@@ -39,22 +37,18 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
send_summary_instruction: bool = True
|
||||
|
||||
async def _send_message_to_gitlab(self, message: str) -> None:
|
||||
"""
|
||||
Send a message to GitLab.
|
||||
"""Send a message to GitLab.
|
||||
|
||||
Args:
|
||||
message: The message content to send to GitLab
|
||||
"""
|
||||
try:
|
||||
# Create a message object for GitHub
|
||||
message_obj = Message(source=SourceType.OPENHANDS, message=message)
|
||||
|
||||
# Get the token manager
|
||||
token_manager = TokenManager()
|
||||
gitlab_manager = GitlabManager(token_manager)
|
||||
|
||||
# Send the message
|
||||
await gitlab_manager.send_message(message_obj, self.gitlab_view)
|
||||
# Send the message directly as a string
|
||||
await gitlab_manager.send_message(message, self.gitlab_view)
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Sent summary message to {self.gitlab_view.full_repo_name}#{self.gitlab_view.issue_number}'
|
||||
@@ -111,9 +105,9 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
self.send_summary_instruction = False
|
||||
callback.set_processor(self)
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
# Extract the summary from the event store
|
||||
@@ -132,9 +126,9 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
# Mark callback as completed status
|
||||
callback.status = CallbackStatus.COMPLETED
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
|
||||
@@ -37,8 +37,7 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace_name: str
|
||||
|
||||
async def _send_comment_to_jira(self, message: str) -> None:
|
||||
"""
|
||||
Send a comment to Jira issue.
|
||||
"""Send a comment to Jira issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Jira
|
||||
@@ -59,8 +58,9 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
|
||||
# Decrypt API key
|
||||
api_key = jira_manager.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
|
||||
# Send comment directly as a string
|
||||
await jira_manager.send_message(
|
||||
jira_manager.create_outgoing_message(msg=message),
|
||||
message,
|
||||
issue_key=self.issue_key,
|
||||
jira_cloud_id=workspace.jira_cloud_id,
|
||||
svc_acc_email=workspace.svc_acc_email,
|
||||
|
||||
@@ -37,8 +37,7 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
|
||||
base_api_url: str
|
||||
|
||||
async def _send_comment_to_jira_dc(self, message: str) -> None:
|
||||
"""
|
||||
Send a comment to Jira DC issue.
|
||||
"""Send a comment to Jira DC issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Jira DC
|
||||
@@ -61,8 +60,9 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
# Send comment directly as a string
|
||||
await jira_dc_manager.send_message(
|
||||
jira_dc_manager.create_outgoing_message(msg=message),
|
||||
message,
|
||||
issue_key=self.issue_key,
|
||||
base_api_url=self.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
|
||||
@@ -36,8 +36,7 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace_name: str
|
||||
|
||||
async def _send_comment_to_linear(self, message: str) -> None:
|
||||
"""
|
||||
Send a comment to Linear issue.
|
||||
"""Send a comment to Linear issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Linear
|
||||
@@ -60,9 +59,9 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
# Send comment
|
||||
# Send comment directly as a string
|
||||
await linear_manager.send_message(
|
||||
linear_manager.create_outgoing_message(msg=message),
|
||||
message,
|
||||
self.issue_id,
|
||||
api_key,
|
||||
)
|
||||
|
||||
@@ -26,8 +26,7 @@ slack_manager = SlackManager(token_manager)
|
||||
|
||||
|
||||
class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
"""
|
||||
Processor for sending conversation summaries to Slack.
|
||||
"""Processor for sending conversation summaries to Slack.
|
||||
|
||||
This processor is used to send summaries of conversations to Slack channels
|
||||
when agent state changes occur.
|
||||
@@ -41,14 +40,13 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
last_user_msg_id: int | None = None
|
||||
|
||||
async def _send_message_to_slack(self, message: str) -> None:
|
||||
"""
|
||||
Send a message to Slack using the conversation_manager's send_to_event_stream method.
|
||||
"""Send a message to Slack.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Slack
|
||||
"""
|
||||
try:
|
||||
# Create a message object for Slack
|
||||
# Create a message object for Slack view creation (incoming message format)
|
||||
message_obj = Message(
|
||||
source=SourceType.SLACK,
|
||||
message={
|
||||
@@ -67,9 +65,8 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
slack_view = SlackFactory.create_slack_view_from_payload(
|
||||
message_obj, slack_user, saas_user_auth
|
||||
)
|
||||
await slack_manager.send_message(
|
||||
slack_manager.create_outgoing_message(message), slack_view
|
||||
)
|
||||
# Send the message directly as a string
|
||||
await slack_manager.send_message(message, slack_view)
|
||||
|
||||
logger.info(
|
||||
f'[Slack] Sent summary message to channel {self.channel_id} '
|
||||
|
||||
@@ -251,7 +251,7 @@ async def delete_api_key(
|
||||
)
|
||||
|
||||
# Delete the key
|
||||
success = api_key_store.delete_api_key_by_id(key_id)
|
||||
success = await api_key_store.delete_api_key_by_id(key_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -34,7 +34,8 @@ from server.services.org_invitation_service import (
|
||||
OrgInvitationService,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
|
||||
@@ -270,7 +271,7 @@ async def keycloak_callback(
|
||||
# Fail open - continue with login if reCAPTCHA service unavailable
|
||||
|
||||
# Check if email domain is blocked
|
||||
if email and domain_blocker.is_domain_blocked(email):
|
||||
if email and await domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
|
||||
)
|
||||
@@ -610,17 +611,20 @@ async def accept_tos(request: Request):
|
||||
|
||||
# Update user settings with TOS acceptance
|
||||
accepted_tos: datetime = datetime.now(timezone.utc)
|
||||
with session_maker() as session:
|
||||
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
logger.error('User for {user_id} not found.')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'User does not exist'},
|
||||
)
|
||||
user.accepted_tos = accepted_tos
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
|
||||
|
||||
@@ -11,9 +11,10 @@ from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
@@ -106,16 +107,17 @@ async def get_subscription_access(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> SubscriptionAccessResponse | None:
|
||||
"""Get details of the currently valid subscription for the user."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.first()
|
||||
result = await session.execute(
|
||||
select(SubscriptionAccess).where(
|
||||
SubscriptionAccess.status == 'ACTIVE',
|
||||
SubscriptionAccess.user_id == user_id,
|
||||
SubscriptionAccess.start_at <= now,
|
||||
SubscriptionAccess.end_at >= now,
|
||||
)
|
||||
)
|
||||
subscription_access = result.scalar_one_or_none()
|
||||
if not subscription_access:
|
||||
return None
|
||||
return SubscriptionAccessResponse(
|
||||
@@ -197,7 +199,7 @@ async def create_checkout_session(
|
||||
'checkout_session_id': checkout_session.id,
|
||||
},
|
||||
)
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
@@ -206,7 +208,7 @@ async def create_checkout_session(
|
||||
price_code='NA',
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
@@ -215,13 +217,14 @@ async def create_checkout_session(
|
||||
@billing_router.get('/success')
|
||||
async def success_callback(session_id: str, request: Request):
|
||||
# We can't use the auth cookie because of SameSite=strict
|
||||
with session_maker() as session:
|
||||
billing_session = (
|
||||
session.query(BillingSession)
|
||||
.filter(BillingSession.id == session_id)
|
||||
.filter(BillingSession.status == 'in_progress')
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(
|
||||
BillingSession.id == session_id,
|
||||
BillingSession.status == 'in_progress',
|
||||
)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
|
||||
if billing_session is None:
|
||||
# Hopefully this never happens - we get a redirect from stripe where the session does not exist
|
||||
@@ -253,7 +256,8 @@ async def success_callback(session_id: str, request: Request):
|
||||
user_team_info, billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
|
||||
org = session.query(Org).filter(Org.id == user.current_org_id).first()
|
||||
result = await session.execute(select(Org).where(Org.id == user.current_org_id))
|
||||
org = result.scalar_one_or_none()
|
||||
new_max_budget = max_budget + add_credits
|
||||
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
@@ -279,7 +283,7 @@ async def success_callback(session_id: str, request: Request):
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
|
||||
@@ -289,13 +293,14 @@ async def success_callback(session_id: str, request: Request):
|
||||
# Callback endpoint for cancelled Stripe payments - updates billing session status
|
||||
@billing_router.get('/cancel')
|
||||
async def cancel_callback(session_id: str, request: Request):
|
||||
with session_maker() as session:
|
||||
billing_session = (
|
||||
session.query(BillingSession)
|
||||
.filter(BillingSession.id == session_id)
|
||||
.filter(BillingSession.status == 'in_progress')
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(
|
||||
BillingSession.id == session_id,
|
||||
BillingSession.status == 'in_progress',
|
||||
)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
if billing_session:
|
||||
logger.info(
|
||||
'stripe_checkout_cancel',
|
||||
@@ -307,7 +312,7 @@ async def cancel_callback(session_id: str, request: Request):
|
||||
billing_session.status = 'cancelled'
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.future import select
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.feedback import ConversationFeedback
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
@@ -11,7 +11,6 @@ from openhands.events.event_store import EventStore
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.shared import file_store
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint
|
||||
# is protected. The actual protection is provided by SetAuthCookieMiddleware
|
||||
@@ -37,23 +36,19 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
|
||||
"""
|
||||
|
||||
# Verify the conversation belongs to the user
|
||||
def _verify_conversation():
|
||||
with session_maker() as session:
|
||||
metadata = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
)
|
||||
)
|
||||
metadata = result.scalars().first()
|
||||
if not metadata:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Conversation {conversation_id} not found',
|
||||
)
|
||||
if not metadata:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Conversation {conversation_id} not found',
|
||||
)
|
||||
|
||||
await call_sync_from_async(_verify_conversation)
|
||||
|
||||
# Create an event store to access the events directly
|
||||
# This works even when the conversation is not running
|
||||
@@ -103,12 +98,9 @@ async def submit_conversation_feedback(feedback: FeedbackRequest):
|
||||
)
|
||||
|
||||
# Add to database
|
||||
def _save_feedback():
|
||||
with session_maker() as session:
|
||||
session.add(new_feedback)
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_save_feedback)
|
||||
async with a_session_maker() as session:
|
||||
session.add(new_feedback)
|
||||
await session.commit()
|
||||
|
||||
return {'status': 'success', 'message': 'Feedback submitted successfully'}
|
||||
|
||||
@@ -127,30 +119,27 @@ async def get_batch_feedback(conversation_id: str, user_id: str = Depends(get_us
|
||||
return {}
|
||||
|
||||
# Query for existing feedback for all events
|
||||
def _check_feedback():
|
||||
with session_maker() as session:
|
||||
result = session.execute(
|
||||
select(ConversationFeedback).where(
|
||||
ConversationFeedback.conversation_id == conversation_id,
|
||||
ConversationFeedback.event_id.in_(event_ids),
|
||||
)
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ConversationFeedback).where(
|
||||
ConversationFeedback.conversation_id == conversation_id,
|
||||
ConversationFeedback.event_id.in_(event_ids),
|
||||
)
|
||||
)
|
||||
|
||||
# Create a mapping of event_id to feedback
|
||||
feedback_map = {
|
||||
feedback.event_id: {
|
||||
'exists': True,
|
||||
'rating': feedback.rating,
|
||||
'reason': feedback.reason,
|
||||
}
|
||||
for feedback in result.scalars()
|
||||
# Create a mapping of event_id to feedback
|
||||
feedback_map = {
|
||||
feedback.event_id: {
|
||||
'exists': True,
|
||||
'rating': feedback.rating,
|
||||
'reason': feedback.reason,
|
||||
}
|
||||
for feedback in result.scalars()
|
||||
}
|
||||
|
||||
# Build response including all events
|
||||
response = {}
|
||||
for event_id in event_ids:
|
||||
response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
|
||||
# Build response including all events
|
||||
response = {}
|
||||
for event_id in event_ids:
|
||||
response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
|
||||
|
||||
return response
|
||||
|
||||
return await call_sync_from_async(_check_feedback)
|
||||
return response
|
||||
|
||||
@@ -308,10 +308,11 @@ async def jira_events(
|
||||
logger.info(f'Processing new Jira webhook event: {signature}')
|
||||
redis_client.setex(key, 300, '1')
|
||||
|
||||
# Process the webhook
|
||||
# Process the webhook in background after returning response.
|
||||
# Note: For async functions, BackgroundTasks runs them in the same event loop
|
||||
# (not a thread pool), so asyncpg connections work correctly.
|
||||
message_payload = {'payload': payload}
|
||||
message = Message(source=SourceType.JIRA, message=message_payload)
|
||||
|
||||
background_tasks.add_task(jira_manager.receive_message, message)
|
||||
|
||||
return JSONResponse({'success': True})
|
||||
|
||||
@@ -7,7 +7,6 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.device_code_store import DeviceCodeStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -54,7 +53,7 @@ class DeviceTokenErrorResponse(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
oauth_device_router = APIRouter(prefix='/oauth/device')
|
||||
device_code_store = DeviceCodeStore(session_maker)
|
||||
device_code_store = DeviceCodeStore()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -90,7 +89,7 @@ async def device_authorization(
|
||||
) -> DeviceAuthorizationResponse:
|
||||
"""Start device flow by generating device and user codes."""
|
||||
try:
|
||||
device_code_entry = device_code_store.create_device_code(
|
||||
device_code_entry = await device_code_store.create_device_code(
|
||||
expires_in=DEVICE_CODE_EXPIRES_IN,
|
||||
)
|
||||
|
||||
@@ -125,7 +124,7 @@ async def device_authorization(
|
||||
async def device_token(device_code: str = Form(...)):
|
||||
"""Poll for a token until the user authorizes or the code expires."""
|
||||
try:
|
||||
device_code_entry = device_code_store.get_by_device_code(device_code)
|
||||
device_code_entry = await device_code_store.get_by_device_code(device_code)
|
||||
|
||||
if not device_code_entry:
|
||||
return _oauth_error(
|
||||
@@ -138,7 +137,9 @@ async def device_token(device_code: str = Form(...)):
|
||||
is_too_fast, current_interval = device_code_entry.check_rate_limit()
|
||||
if is_too_fast:
|
||||
# Update poll time and increase interval
|
||||
device_code_store.update_poll_time(device_code, increase_interval=True)
|
||||
await device_code_store.update_poll_time(
|
||||
device_code, increase_interval=True
|
||||
)
|
||||
logger.warning(
|
||||
'Client polling too fast, returning slow_down error',
|
||||
extra={
|
||||
@@ -154,7 +155,7 @@ async def device_token(device_code: str = Form(...)):
|
||||
)
|
||||
|
||||
# Update poll time for successful rate limit check
|
||||
device_code_store.update_poll_time(device_code, increase_interval=False)
|
||||
await device_code_store.update_poll_time(device_code, increase_interval=False)
|
||||
|
||||
if device_code_entry.is_expired():
|
||||
return _oauth_error(
|
||||
@@ -181,7 +182,7 @@ async def device_token(device_code: str = Form(...)):
|
||||
# Retrieve the specific API key for this device using the user_code
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
device_key_name = f'{API_KEY_NAME} ({device_code_entry.user_code})'
|
||||
device_api_key = api_key_store.retrieve_api_key_by_name(
|
||||
device_api_key = await api_key_store.retrieve_api_key_by_name(
|
||||
device_code_entry.keycloak_user_id, device_key_name
|
||||
)
|
||||
|
||||
@@ -238,7 +239,7 @@ async def device_verification_authenticated(
|
||||
)
|
||||
|
||||
# Validate device code
|
||||
device_code_entry = device_code_store.get_by_user_code(user_code)
|
||||
device_code_entry = await device_code_store.get_by_user_code(user_code)
|
||||
if not device_code_entry:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -252,7 +253,7 @@ async def device_verification_authenticated(
|
||||
)
|
||||
|
||||
# First, authorize the device code
|
||||
success = device_code_store.authorize_device_code(
|
||||
success = await device_code_store.authorize_device_code(
|
||||
user_code=user_code,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -289,7 +290,7 @@ async def device_verification_authenticated(
|
||||
# Clean up: revert the device authorization since API key creation failed
|
||||
# This prevents the device from being in an authorized state without an API key
|
||||
try:
|
||||
device_code_store.deny_device_code(user_code)
|
||||
await device_code_store.deny_device_code(user_code)
|
||||
logger.info(
|
||||
'Reverted device authorization due to API key creation failure',
|
||||
extra={'user_code': user_code, 'user_id': user_id},
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field, SecretStr, StringConstraints
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
EmailStr,
|
||||
Field,
|
||||
SecretStr,
|
||||
StringConstraints,
|
||||
field_validator,
|
||||
)
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
@@ -252,6 +259,115 @@ class OrgUpdate(BaseModel):
|
||||
condenser_max_size: int | None = Field(default=None, ge=20)
|
||||
|
||||
|
||||
class OrgLLMSettingsResponse(BaseModel):
|
||||
"""Response model for organization LLM settings."""
|
||||
|
||||
default_llm_model: str | None = None
|
||||
default_llm_base_url: str | None = None
|
||||
search_api_key: str | None = None # Masked in response
|
||||
agent: str | None = None
|
||||
confirmation_mode: bool | None = None
|
||||
security_analyzer: str | None = None
|
||||
enable_default_condenser: bool = True
|
||||
condenser_max_size: int | None = None
|
||||
default_max_iterations: int | None = None
|
||||
|
||||
@staticmethod
|
||||
def _mask_key(secret: SecretStr | None) -> str | None:
|
||||
"""Mask an API key, showing only last 4 characters."""
|
||||
if secret is None:
|
||||
return None
|
||||
raw = secret.get_secret_value()
|
||||
if not raw:
|
||||
return None
|
||||
if len(raw) <= 4:
|
||||
return '****'
|
||||
return '****' + raw[-4:]
|
||||
|
||||
@classmethod
|
||||
def from_org(cls, org: Org) -> 'OrgLLMSettingsResponse':
|
||||
"""Create response from Org entity."""
|
||||
return cls(
|
||||
default_llm_model=org.default_llm_model,
|
||||
default_llm_base_url=org.default_llm_base_url,
|
||||
search_api_key=cls._mask_key(org.search_api_key),
|
||||
agent=org.agent,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
security_analyzer=org.security_analyzer,
|
||||
enable_default_condenser=org.enable_default_condenser
|
||||
if org.enable_default_condenser is not None
|
||||
else True,
|
||||
condenser_max_size=org.condenser_max_size,
|
||||
default_max_iterations=org.default_max_iterations,
|
||||
)
|
||||
|
||||
|
||||
class OrgMemberLLMSettings(BaseModel):
|
||||
"""LLM settings to propagate to organization members.
|
||||
|
||||
Field names match OrgMember DB columns.
|
||||
"""
|
||||
|
||||
llm_model: str | None = None
|
||||
llm_base_url: str | None = None
|
||||
max_iterations: int | None = None
|
||||
llm_api_key: str | None = None
|
||||
|
||||
def has_updates(self) -> bool:
|
||||
"""Check if any field is set (not None)."""
|
||||
return any(getattr(self, field) is not None for field in self.model_fields)
|
||||
|
||||
|
||||
class OrgLLMSettingsUpdate(BaseModel):
|
||||
"""Request model for updating organization LLM settings.
|
||||
|
||||
Field names match Org DB columns exactly.
|
||||
"""
|
||||
|
||||
default_llm_model: str | None = None
|
||||
default_llm_base_url: str | None = None
|
||||
search_api_key: str | None = None
|
||||
agent: str | None = None
|
||||
confirmation_mode: bool | None = None
|
||||
security_analyzer: str | None = None
|
||||
enable_default_condenser: bool | None = None
|
||||
condenser_max_size: int | None = Field(default=None, ge=20)
|
||||
default_max_iterations: int | None = Field(default=None, gt=0)
|
||||
llm_api_key: str | None = None
|
||||
|
||||
def has_updates(self) -> bool:
|
||||
"""Check if any field is set (not None)."""
|
||||
return any(getattr(self, field) is not None for field in self.model_fields)
|
||||
|
||||
def apply_to_org(self, org: Org) -> None:
|
||||
"""Apply non-None settings to the organization model.
|
||||
|
||||
Args:
|
||||
org: Organization entity to update in place
|
||||
"""
|
||||
for field_name in self.model_fields:
|
||||
value = getattr(self, field_name)
|
||||
# Skip llm_api_key - it's only for member propagation, not org-level
|
||||
if value is not None and field_name != 'llm_api_key':
|
||||
setattr(org, field_name, value)
|
||||
|
||||
def get_member_updates(self) -> OrgMemberLLMSettings | None:
|
||||
"""Get updates that need to be propagated to org members.
|
||||
|
||||
Returns:
|
||||
OrgMemberLLMSettings with mapped field values, or None if no member updates needed.
|
||||
Maps: default_llm_model → llm_model, default_llm_base_url → llm_base_url,
|
||||
default_max_iterations → max_iterations, llm_api_key → llm_api_key
|
||||
"""
|
||||
member_settings = OrgMemberLLMSettings(
|
||||
llm_model=self.default_llm_model,
|
||||
llm_base_url=self.default_llm_base_url,
|
||||
max_iterations=self.default_max_iterations,
|
||||
llm_api_key=self.llm_api_key,
|
||||
)
|
||||
return member_settings if member_settings.has_updates() else None
|
||||
|
||||
|
||||
class OrgMemberResponse(BaseModel):
|
||||
"""Response model for a single organization member."""
|
||||
|
||||
@@ -327,3 +443,44 @@ class MeResponse(BaseModel):
|
||||
llm_base_url=member.llm_base_url,
|
||||
status=member.status,
|
||||
)
|
||||
|
||||
|
||||
class OrgAppSettingsResponse(BaseModel):
|
||||
"""Response model for organization app settings."""
|
||||
|
||||
enable_proactive_conversation_starters: bool = True
|
||||
enable_solvability_analysis: bool | None = None
|
||||
max_budget_per_task: float | None = None
|
||||
|
||||
@classmethod
|
||||
def from_org(cls, org: Org) -> 'OrgAppSettingsResponse':
|
||||
"""Create an OrgAppSettingsResponse from an Org entity.
|
||||
|
||||
Args:
|
||||
org: The organization entity
|
||||
|
||||
Returns:
|
||||
OrgAppSettingsResponse with app settings
|
||||
"""
|
||||
return cls(
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters
|
||||
if org.enable_proactive_conversation_starters is not None
|
||||
else True,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
)
|
||||
|
||||
|
||||
class OrgAppSettingsUpdate(BaseModel):
|
||||
"""Request model for updating organization app settings."""
|
||||
|
||||
enable_proactive_conversation_starters: bool | None = None
|
||||
enable_solvability_analysis: bool | None = None
|
||||
max_budget_per_task: float | None = None
|
||||
|
||||
@field_validator('max_budget_per_task')
|
||||
@classmethod
|
||||
def validate_max_budget_per_task(cls, v: float | None) -> float | None:
|
||||
if v is not None and v <= 0:
|
||||
raise ValueError('max_budget_per_task must be greater than 0')
|
||||
return v
|
||||
|
||||
@@ -15,9 +15,13 @@ from server.routes.org_models import (
|
||||
LiteLLMIntegrationError,
|
||||
MemberUpdateError,
|
||||
MeResponse,
|
||||
OrgAppSettingsResponse,
|
||||
OrgAppSettingsUpdate,
|
||||
OrgAuthorizationError,
|
||||
OrgCreate,
|
||||
OrgDatabaseError,
|
||||
OrgLLMSettingsResponse,
|
||||
OrgLLMSettingsUpdate,
|
||||
OrgMemberNotFoundError,
|
||||
OrgMemberPage,
|
||||
OrgMemberResponse,
|
||||
@@ -30,6 +34,14 @@ from server.routes.org_models import (
|
||||
OrphanedUserError,
|
||||
RoleNotFoundError,
|
||||
)
|
||||
from server.services.org_app_settings_service import (
|
||||
OrgAppSettingsService,
|
||||
OrgAppSettingsServiceInjector,
|
||||
)
|
||||
from server.services.org_llm_settings_service import (
|
||||
OrgLLMSettingsService,
|
||||
OrgLLMSettingsServiceInjector,
|
||||
)
|
||||
from server.services.org_member_service import OrgMemberService
|
||||
from storage.org_service import OrgService
|
||||
from storage.user_store import UserStore
|
||||
@@ -40,6 +52,13 @@ from openhands.server.user_auth import get_user_id
|
||||
# Initialize API router
|
||||
org_router = APIRouter(prefix='/api/organizations', tags=['Orgs'])
|
||||
|
||||
# Create injector instance and dependency for LLM settings
|
||||
_org_llm_settings_injector = OrgLLMSettingsServiceInjector()
|
||||
org_llm_settings_service_dependency = Depends(_org_llm_settings_injector.depends)
|
||||
# Create injector instance and dependency at module level
|
||||
_org_app_settings_injector = OrgAppSettingsServiceInjector()
|
||||
org_app_settings_service_dependency = Depends(_org_app_settings_injector.depends)
|
||||
|
||||
|
||||
@org_router.get('', response_model=OrgPage)
|
||||
async def list_user_orgs(
|
||||
@@ -201,6 +220,195 @@ async def create_org(
|
||||
)
|
||||
|
||||
|
||||
@org_router.get(
|
||||
'/llm',
|
||||
response_model=OrgLLMSettingsResponse,
|
||||
dependencies=[Depends(require_permission(Permission.VIEW_LLM_SETTINGS))],
|
||||
)
|
||||
async def get_org_llm_settings(
|
||||
service: OrgLLMSettingsService = org_llm_settings_service_dependency,
|
||||
) -> OrgLLMSettingsResponse:
|
||||
"""Get LLM settings for the user's current organization.
|
||||
|
||||
This endpoint retrieves the LLM configuration settings for the
|
||||
authenticated user's current organization. All organization members
|
||||
can view these settings.
|
||||
|
||||
Args:
|
||||
service: OrgLLMSettingsService (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgLLMSettingsResponse: The organization's LLM settings
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if not authenticated
|
||||
HTTPException: 403 if not a member of any organization
|
||||
HTTPException: 404 if current organization not found
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
try:
|
||||
return await service.get_org_llm_settings()
|
||||
except OrgNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error getting organization LLM settings',
|
||||
extra={'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve LLM settings',
|
||||
)
|
||||
|
||||
|
||||
@org_router.post(
|
||||
'/llm',
|
||||
response_model=OrgLLMSettingsResponse,
|
||||
dependencies=[Depends(require_permission(Permission.EDIT_LLM_SETTINGS))],
|
||||
)
|
||||
async def update_org_llm_settings(
|
||||
settings: OrgLLMSettingsUpdate,
|
||||
service: OrgLLMSettingsService = org_llm_settings_service_dependency,
|
||||
) -> OrgLLMSettingsResponse:
|
||||
"""Update LLM settings for the user's current organization.
|
||||
|
||||
This endpoint updates the LLM configuration settings for the
|
||||
authenticated user's current organization. Only admins and owners
|
||||
can update these settings.
|
||||
|
||||
Args:
|
||||
settings: The LLM settings to update (only non-None fields are updated)
|
||||
service: OrgLLMSettingsService (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgLLMSettingsResponse: The updated organization's LLM settings
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if not authenticated
|
||||
HTTPException: 403 if user lacks EDIT_LLM_SETTINGS permission
|
||||
HTTPException: 404 if current organization not found
|
||||
HTTPException: 500 if update fails
|
||||
"""
|
||||
try:
|
||||
return await service.update_org_llm_settings(settings)
|
||||
except OrgNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database error updating LLM settings',
|
||||
extra={'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update LLM settings',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error updating organization LLM settings',
|
||||
extra={'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update LLM settings',
|
||||
)
|
||||
|
||||
|
||||
@org_router.get(
|
||||
'/app',
|
||||
response_model=OrgAppSettingsResponse,
|
||||
dependencies=[Depends(require_permission(Permission.MANAGE_APPLICATION_SETTINGS))],
|
||||
)
|
||||
async def get_org_app_settings(
|
||||
service: OrgAppSettingsService = org_app_settings_service_dependency,
|
||||
) -> OrgAppSettingsResponse:
|
||||
"""Get organization app settings for the user's current organization.
|
||||
|
||||
This endpoint retrieves application settings for the authenticated user's
|
||||
current organization. Access requires the MANAGE_APPLICATION_SETTINGS permission,
|
||||
which is granted to all organization members (member, admin, and owner roles).
|
||||
|
||||
Args:
|
||||
service: OrgAppSettingsService (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgAppSettingsResponse: The organization app settings
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks MANAGE_APPLICATION_SETTINGS permission
|
||||
HTTPException: 404 if current organization not found
|
||||
"""
|
||||
try:
|
||||
return await service.get_org_app_settings()
|
||||
except OrgNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Current organization not found',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error retrieving organization app settings',
|
||||
extra={'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.post(
|
||||
'/app',
|
||||
response_model=OrgAppSettingsResponse,
|
||||
dependencies=[Depends(require_permission(Permission.MANAGE_APPLICATION_SETTINGS))],
|
||||
)
|
||||
async def update_org_app_settings(
|
||||
update_data: OrgAppSettingsUpdate,
|
||||
service: OrgAppSettingsService = org_app_settings_service_dependency,
|
||||
) -> OrgAppSettingsResponse:
|
||||
"""Update organization app settings for the user's current organization.
|
||||
|
||||
This endpoint updates application settings for the authenticated user's
|
||||
current organization. Access requires the MANAGE_APPLICATION_SETTINGS permission,
|
||||
which is granted to all organization members (member, admin, and owner roles).
|
||||
|
||||
Args:
|
||||
update_data: App settings update data
|
||||
service: OrgAppSettingsService (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgAppSettingsResponse: The updated organization app settings
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks MANAGE_APPLICATION_SETTINGS permission
|
||||
HTTPException: 404 if current organization not found
|
||||
HTTPException: 422 if validation errors occur (handled by FastAPI)
|
||||
HTTPException: 500 if update fails
|
||||
"""
|
||||
try:
|
||||
return await service.update_org_app_settings(update_data)
|
||||
except OrgNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Current organization not found',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error updating organization app settings',
|
||||
extra={'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
|
||||
async def get_org(
|
||||
org_id: UUID,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from sqlalchemy.sql import text
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.redis import create_redis_client
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -9,11 +9,11 @@ readiness_router = APIRouter()
|
||||
|
||||
|
||||
@readiness_router.get('/ready')
|
||||
def is_ready():
|
||||
async def is_ready():
|
||||
# Check database connection
|
||||
try:
|
||||
with session_maker() as session:
|
||||
session.execute(text('SELECT 1'))
|
||||
async with a_session_maker() as session:
|
||||
await session.execute(text('SELECT 1'))
|
||||
except Exception as e:
|
||||
logger.error(f'Database check failed: {str(e)}')
|
||||
raise HTTPException(
|
||||
|
||||
@@ -388,5 +388,4 @@ async def _check_idp(
|
||||
access_token.get_secret_value(), ProviderType(idp)
|
||||
):
|
||||
return default_value
|
||||
|
||||
return None
|
||||
|
||||
@@ -1,184 +0,0 @@
|
||||
"""API routes for managing verified LLM models (admin only)."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from server.email_validation import get_admin_user_id
|
||||
from storage.verified_model_store import VerifiedModelStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
api_router = APIRouter(prefix='/api/admin/verified-models', tags=['Verified Models'])
|
||||
|
||||
|
||||
class VerifiedModelCreate(BaseModel):
|
||||
model_name: str
|
||||
provider: str
|
||||
is_enabled: bool = True
|
||||
|
||||
@field_validator('model_name')
|
||||
@classmethod
|
||||
def validate_model_name(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
if not v or len(v) > 255:
|
||||
raise ValueError('model_name must be 1-255 characters')
|
||||
return v
|
||||
|
||||
@field_validator('provider')
|
||||
@classmethod
|
||||
def validate_provider(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
if not v or len(v) > 100:
|
||||
raise ValueError('provider must be 1-100 characters')
|
||||
return v
|
||||
|
||||
|
||||
class VerifiedModelUpdate(BaseModel):
|
||||
is_enabled: bool | None = None
|
||||
|
||||
|
||||
class VerifiedModelResponse(BaseModel):
|
||||
id: int
|
||||
model_name: str
|
||||
provider: str
|
||||
is_enabled: bool
|
||||
|
||||
|
||||
class VerifiedModelPage(BaseModel):
|
||||
"""Paginated response model for verified model list."""
|
||||
|
||||
items: list[VerifiedModelResponse]
|
||||
next_page_id: str | None = None
|
||||
|
||||
|
||||
def _to_response(model) -> VerifiedModelResponse:
|
||||
return VerifiedModelResponse(
|
||||
id=model.id,
|
||||
model_name=model.model_name,
|
||||
provider=model.provider,
|
||||
is_enabled=model.is_enabled,
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('', response_model=VerifiedModelPage)
|
||||
async def list_verified_models(
|
||||
provider: str | None = None,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int, Query(title='The max number of results in the page', gt=0, le=100)
|
||||
] = 100,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
):
|
||||
"""List all verified models, optionally filtered by provider."""
|
||||
try:
|
||||
if provider:
|
||||
all_models = VerifiedModelStore.get_models_by_provider(provider)
|
||||
else:
|
||||
all_models = VerifiedModelStore.get_all_models()
|
||||
|
||||
try:
|
||||
offset = int(page_id) if page_id else 0
|
||||
except ValueError:
|
||||
offset = 0
|
||||
page = all_models[offset : offset + limit + 1]
|
||||
has_more = len(page) > limit
|
||||
if has_more:
|
||||
page = page[:limit]
|
||||
|
||||
return VerifiedModelPage(
|
||||
items=[_to_response(m) for m in page],
|
||||
next_page_id=str(offset + limit) if has_more else None,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error listing verified models')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to list verified models',
|
||||
)
|
||||
|
||||
|
||||
@api_router.post('', response_model=VerifiedModelResponse, status_code=201)
|
||||
async def create_verified_model(
|
||||
data: VerifiedModelCreate,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
):
|
||||
"""Create a new verified model."""
|
||||
try:
|
||||
model = VerifiedModelStore.create_model(
|
||||
model_name=data.model_name,
|
||||
provider=data.provider,
|
||||
is_enabled=data.is_enabled,
|
||||
)
|
||||
return _to_response(model)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error creating verified model')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create verified model',
|
||||
)
|
||||
|
||||
|
||||
@api_router.put('/{provider}/{model_name:path}', response_model=VerifiedModelResponse)
|
||||
async def update_verified_model(
|
||||
provider: str,
|
||||
model_name: str,
|
||||
data: VerifiedModelUpdate,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
):
|
||||
"""Update a verified model by provider and model name."""
|
||||
try:
|
||||
model = VerifiedModelStore.update_model(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
is_enabled=data.is_enabled,
|
||||
)
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Model {provider}/{model_name} not found',
|
||||
)
|
||||
return _to_response(model)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(f'Error updating verified model: {provider}/{model_name}')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update verified model',
|
||||
)
|
||||
|
||||
|
||||
@api_router.delete('/{provider}/{model_name:path}')
|
||||
async def delete_verified_model(
|
||||
provider: str,
|
||||
model_name: str,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
):
|
||||
"""Delete a verified model by provider and model name."""
|
||||
try:
|
||||
success = VerifiedModelStore.delete_model(
|
||||
model_name=model_name, provider=provider
|
||||
)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Model {provider}/{model_name} not found',
|
||||
)
|
||||
return {'message': f'Model {provider}/{model_name} deleted'}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(f'Error deleting verified model: {provider}/{model_name}')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to delete verified model',
|
||||
)
|
||||
130
enterprise/server/services/org_app_settings_service.py
Normal file
130
enterprise/server/services/org_app_settings_service.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Service class for managing organization app settings.
|
||||
|
||||
Separates business logic from route handlers.
|
||||
Uses dependency injection for db_session and user_context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import Request
|
||||
from server.routes.org_models import (
|
||||
OrgAppSettingsResponse,
|
||||
OrgAppSettingsUpdate,
|
||||
OrgNotFoundError,
|
||||
)
|
||||
from storage.org_app_settings_store import OrgAppSettingsStore
|
||||
|
||||
from openhands.app_server.services.injector import Injector, InjectorState
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrgAppSettingsService:
|
||||
"""Service for organization app settings with injected dependencies."""
|
||||
|
||||
store: OrgAppSettingsStore
|
||||
user_context: UserContext
|
||||
|
||||
async def get_org_app_settings(self) -> OrgAppSettingsResponse:
|
||||
"""Get organization app settings.
|
||||
|
||||
User ID is obtained from the injected user_context.
|
||||
|
||||
Returns:
|
||||
OrgAppSettingsResponse: The organization's app settings
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If current organization is not found
|
||||
"""
|
||||
user_id = await self.user_context.get_user_id()
|
||||
|
||||
logger.info(
|
||||
'Getting organization app settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
org = await self.store.get_current_org_by_user_id(user_id)
|
||||
|
||||
if not org:
|
||||
raise OrgNotFoundError('current')
|
||||
|
||||
return OrgAppSettingsResponse.from_org(org)
|
||||
|
||||
async def update_org_app_settings(
|
||||
self,
|
||||
update_data: OrgAppSettingsUpdate,
|
||||
) -> OrgAppSettingsResponse:
|
||||
"""Update organization app settings.
|
||||
|
||||
Only updates fields that are explicitly provided in update_data.
|
||||
User ID is obtained from the injected user_context.
|
||||
Session auto-commits at request end via DbSessionInjector.
|
||||
|
||||
Args:
|
||||
update_data: The update data from the request
|
||||
|
||||
Returns:
|
||||
OrgAppSettingsResponse: The updated organization's app settings
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If current organization is not found
|
||||
"""
|
||||
user_id = await self.user_context.get_user_id()
|
||||
|
||||
logger.info(
|
||||
'Updating organization app settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Get current org first
|
||||
org = await self.store.get_current_org_by_user_id(user_id)
|
||||
|
||||
if not org:
|
||||
raise OrgNotFoundError('current')
|
||||
|
||||
# Check if any fields are provided
|
||||
update_dict = update_data.model_dump(exclude_unset=True)
|
||||
|
||||
if not update_dict:
|
||||
# No fields to update, just return current settings
|
||||
logger.info(
|
||||
'No fields to update in app settings',
|
||||
extra={'user_id': user_id, 'org_id': str(org.id)},
|
||||
)
|
||||
return OrgAppSettingsResponse.from_org(org)
|
||||
|
||||
updated_org = await self.store.update_org_app_settings(
|
||||
org_id=org.id,
|
||||
update_data=update_data,
|
||||
)
|
||||
|
||||
if not updated_org:
|
||||
raise OrgNotFoundError('current')
|
||||
|
||||
logger.info(
|
||||
'Organization app settings updated successfully',
|
||||
extra={'user_id': user_id, 'updated_fields': list(update_dict.keys())},
|
||||
)
|
||||
|
||||
return OrgAppSettingsResponse.from_org(updated_org)
|
||||
|
||||
|
||||
class OrgAppSettingsServiceInjector(Injector[OrgAppSettingsService]):
|
||||
"""Injector that composes store and user_context for OrgAppSettingsService."""
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[OrgAppSettingsService, None]:
|
||||
# Local imports to avoid circular dependencies
|
||||
from openhands.app_server.config import get_db_session, get_user_context
|
||||
|
||||
async with (
|
||||
get_user_context(state, request) as user_context,
|
||||
get_db_session(state, request) as db_session,
|
||||
):
|
||||
store = OrgAppSettingsStore(db_session=db_session)
|
||||
yield OrgAppSettingsService(store=store, user_context=user_context)
|
||||
130
enterprise/server/services/org_llm_settings_service.py
Normal file
130
enterprise/server/services/org_llm_settings_service.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Service class for managing organization LLM settings.
|
||||
|
||||
Separates business logic from route handlers.
|
||||
Uses dependency injection for db_session and user_context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import Request
|
||||
from server.routes.org_models import (
|
||||
OrgLLMSettingsResponse,
|
||||
OrgLLMSettingsUpdate,
|
||||
OrgNotFoundError,
|
||||
)
|
||||
from storage.org_llm_settings_store import OrgLLMSettingsStore
|
||||
|
||||
from openhands.app_server.services.injector import Injector, InjectorState
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrgLLMSettingsService:
|
||||
"""Service for org LLM settings with injected dependencies."""
|
||||
|
||||
store: OrgLLMSettingsStore
|
||||
user_context: UserContext
|
||||
|
||||
async def get_org_llm_settings(self) -> OrgLLMSettingsResponse:
|
||||
"""Get LLM settings for user's current organization.
|
||||
|
||||
User ID is obtained from the injected user_context.
|
||||
|
||||
Returns:
|
||||
OrgLLMSettingsResponse: The organization's LLM settings
|
||||
|
||||
Raises:
|
||||
ValueError: If user is not authenticated
|
||||
OrgNotFoundError: If current organization not found
|
||||
"""
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if not user_id:
|
||||
raise ValueError('User is not authenticated')
|
||||
|
||||
logger.info(
|
||||
'Getting organization LLM settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
org = await self.store.get_current_org_by_user_id(user_id)
|
||||
|
||||
if not org:
|
||||
raise OrgNotFoundError('No current organization')
|
||||
|
||||
return OrgLLMSettingsResponse.from_org(org)
|
||||
|
||||
async def update_org_llm_settings(
|
||||
self,
|
||||
update_data: OrgLLMSettingsUpdate,
|
||||
) -> OrgLLMSettingsResponse:
|
||||
"""Update LLM settings for user's current organization.
|
||||
|
||||
Only updates fields that are explicitly provided in update_data.
|
||||
User ID is obtained from the injected user_context.
|
||||
Session auto-commits at request end via DbSessionInjector.
|
||||
|
||||
Args:
|
||||
update_data: The update data from the request
|
||||
|
||||
Returns:
|
||||
OrgLLMSettingsResponse: The updated organization's LLM settings
|
||||
|
||||
Raises:
|
||||
ValueError: If user is not authenticated
|
||||
OrgNotFoundError: If current organization not found
|
||||
"""
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if not user_id:
|
||||
raise ValueError('User is not authenticated')
|
||||
|
||||
logger.info(
|
||||
'Updating organization LLM settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Check if any fields are provided
|
||||
if not update_data.has_updates():
|
||||
# No fields to update, just return current settings
|
||||
return await self.get_org_llm_settings()
|
||||
|
||||
# Get user's current org first
|
||||
org = await self.store.get_current_org_by_user_id(user_id)
|
||||
if not org:
|
||||
raise OrgNotFoundError('No current organization')
|
||||
|
||||
# Update the org LLM settings
|
||||
updated_org = await self.store.update_org_llm_settings(
|
||||
org_id=org.id,
|
||||
update_data=update_data,
|
||||
)
|
||||
|
||||
if not updated_org:
|
||||
raise OrgNotFoundError(str(org.id))
|
||||
|
||||
logger.info(
|
||||
'Organization LLM settings updated successfully',
|
||||
extra={'user_id': user_id, 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
return OrgLLMSettingsResponse.from_org(updated_org)
|
||||
|
||||
|
||||
class OrgLLMSettingsServiceInjector(Injector[OrgLLMSettingsService]):
|
||||
"""Injector that composes store and user_context for OrgLLMSettingsService."""
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[OrgLLMSettingsService, None]:
|
||||
# Local imports to avoid circular dependencies
|
||||
from openhands.app_server.config import get_db_session, get_user_context
|
||||
|
||||
async with (
|
||||
get_user_context(state, request) as user_context,
|
||||
get_db_session(state, request) as db_session,
|
||||
):
|
||||
store = OrgLLMSettingsStore(db_session=db_session)
|
||||
yield OrgLLMSettingsService(store=store, user_context=user_context)
|
||||
@@ -4,13 +4,14 @@ import pickle
|
||||
from datetime import datetime
|
||||
|
||||
from server.logger import logger
|
||||
from sqlalchemy import and_, select
|
||||
from storage.conversation_callback import (
|
||||
CallbackStatus,
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
@@ -79,15 +80,16 @@ async def invoke_conversation_callbacks(
|
||||
conversation_id: The conversation ID to process callbacks for
|
||||
observation: The AgentStateChangedObservation that triggered the callback
|
||||
"""
|
||||
with session_maker() as session:
|
||||
callbacks = (
|
||||
session.query(ConversationCallback)
|
||||
.filter(
|
||||
ConversationCallback.conversation_id == conversation_id,
|
||||
ConversationCallback.status == CallbackStatus.ACTIVE,
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ConversationCallback).filter(
|
||||
and_(
|
||||
ConversationCallback.conversation_id == conversation_id,
|
||||
ConversationCallback.status == CallbackStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
callbacks = result.scalars().all()
|
||||
|
||||
for callback in callbacks:
|
||||
try:
|
||||
@@ -115,7 +117,7 @@ async def invoke_conversation_callbacks(
|
||||
callback.status = CallbackStatus.ERROR
|
||||
callback.updated_at = datetime.now()
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
|
||||
def update_conversation_metadata(conversation_id: str, content: dict):
|
||||
|
||||
33
enterprise/server/verified_models/verified_model_models.py
Normal file
33
enterprise/server/verified_models/verified_model_models.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, StringConstraints
|
||||
|
||||
|
||||
class VerifiedModelCreate(BaseModel):
|
||||
model_name: Annotated[
|
||||
str,
|
||||
StringConstraints(max_length=255),
|
||||
]
|
||||
provider: Annotated[
|
||||
str,
|
||||
StringConstraints(max_length=100),
|
||||
]
|
||||
is_enabled: bool = True
|
||||
|
||||
|
||||
class VerifiedModel(VerifiedModelCreate):
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class VerifiedModelUpdate(BaseModel):
|
||||
is_enabled: bool | None = None
|
||||
|
||||
|
||||
class VerifiedModelPage(BaseModel):
|
||||
"""Paginated response model for verified model list."""
|
||||
|
||||
items: list[VerifiedModel]
|
||||
next_page_id: str | None = None
|
||||
143
enterprise/server/verified_models/verified_model_router.py
Normal file
143
enterprise/server/verified_models/verified_model_router.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""API routes for managing verified LLM models (admin only)."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.verified_models.verified_model_models import (
|
||||
VerifiedModel,
|
||||
VerifiedModelCreate,
|
||||
VerifiedModelPage,
|
||||
VerifiedModelUpdate,
|
||||
)
|
||||
from server.verified_models.verified_model_service import (
|
||||
VerifiedModelService,
|
||||
verified_model_store_dependency,
|
||||
)
|
||||
|
||||
from openhands.app_server.config import get_db_session
|
||||
from openhands.server.routes import public
|
||||
from openhands.utils.llm import get_supported_llm_models
|
||||
|
||||
api_router = APIRouter(prefix='/api/admin/verified-models', tags=['Verified Models'])
|
||||
|
||||
|
||||
@api_router.get('')
|
||||
async def search_verified_models(
|
||||
provider: str | None = None,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int, Query(title='The max number of results in the page', gt=0, le=100)
|
||||
] = 100,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
verified_model_service: VerifiedModelService = Depends(
|
||||
verified_model_store_dependency
|
||||
),
|
||||
) -> VerifiedModelPage:
|
||||
"""List all verified models, optionally filtered by provider."""
|
||||
# Use SQL-level filtering and pagination
|
||||
result = await verified_model_service.search_verified_models(
|
||||
provider=provider,
|
||||
enabled_only=False, # Admin sees all models including disabled
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@api_router.post('', status_code=201)
|
||||
async def create_verified_model(
|
||||
data: VerifiedModelCreate,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
verified_model_service: VerifiedModelService = Depends(
|
||||
verified_model_store_dependency
|
||||
),
|
||||
) -> VerifiedModel:
|
||||
"""Create a new verified model."""
|
||||
try:
|
||||
model = await verified_model_service.create_verified_model(
|
||||
model_name=data.model_name,
|
||||
provider=data.provider,
|
||||
is_enabled=data.is_enabled,
|
||||
)
|
||||
return model
|
||||
except ValueError as ex:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(ex),
|
||||
)
|
||||
|
||||
|
||||
@api_router.put('/{provider}/{model_name:path}')
|
||||
async def update_verified_model(
|
||||
provider: str,
|
||||
model_name: str,
|
||||
data: VerifiedModelUpdate,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
verified_model_service: VerifiedModelService = Depends(
|
||||
verified_model_store_dependency
|
||||
),
|
||||
) -> VerifiedModel:
|
||||
"""Update a verified model by provider and model name."""
|
||||
model = await verified_model_service.update_verified_model(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
is_enabled=data.is_enabled,
|
||||
)
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Model {provider}/{model_name} not found',
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
@api_router.delete('/{provider}/{model_name:path}')
|
||||
async def delete_verified_model(
|
||||
provider: str,
|
||||
model_name: str,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
verified_model_service: VerifiedModelService = Depends(
|
||||
verified_model_store_dependency
|
||||
),
|
||||
) -> bool:
|
||||
"""Delete a verified model by provider and model name."""
|
||||
try:
|
||||
await verified_model_service.delete_verified_model(
|
||||
model_name=model_name, provider=provider
|
||||
)
|
||||
return True
|
||||
except ValueError as ex:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(ex),
|
||||
)
|
||||
|
||||
|
||||
async def get_saas_llm_models_dependency(request: Request) -> list[str]:
|
||||
"""SaaS implementation for the LLM models endpoint."""
|
||||
async with get_db_session(request.state, request) as db_session:
|
||||
# Prevent circular import
|
||||
from openhands.server.shared import config
|
||||
|
||||
verified_model_service = VerifiedModelService(db_session)
|
||||
page = await verified_model_service.search_verified_models(enabled_only=True)
|
||||
if page.next_page_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Too many models defined in database',
|
||||
)
|
||||
verified_models = [f'{m.provider}/{m.model_name}' for m in page.items]
|
||||
return get_supported_llm_models(config, verified_models)
|
||||
|
||||
|
||||
# Override the default implementation with SaaS implementation
|
||||
# This must be called after the app is created in saas_server.py
|
||||
def override_llm_models_dependency(app):
|
||||
"""Override the default LLM models implementation with SaaS version."""
|
||||
app.dependency_overrides[public.get_llm_models_dependency] = (
|
||||
get_saas_llm_models_dependency
|
||||
)
|
||||
242
enterprise/server/verified_models/verified_model_service.py
Normal file
242
enterprise/server/verified_models/verified_model_service.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Store for managing verified LLM models in the database."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from server.verified_models.verified_model_models import (
|
||||
VerifiedModel,
|
||||
VerifiedModelPage,
|
||||
)
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Identity,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
and_,
|
||||
func,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.base import Base
|
||||
|
||||
from openhands.app_server.config import depends_db_session
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class StoredVerifiedModel(Base): # type: ignore
|
||||
"""A verified LLM model available in the model selector.
|
||||
|
||||
The composite unique constraint on (model_name, provider) allows the same
|
||||
model name to exist under different providers (e.g. 'claude-sonnet' under
|
||||
both 'openhands' and 'anthropic').
|
||||
"""
|
||||
|
||||
__tablename__ = 'verified_models'
|
||||
__table_args__ = (
|
||||
UniqueConstraint('model_name', 'provider', name='uq_verified_model_provider'),
|
||||
)
|
||||
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
model_name = Column(String(255), nullable=False)
|
||||
provider = Column(String(100), nullable=False, index=True)
|
||||
is_enabled = Column(
|
||||
Boolean, nullable=False, default=True, server_default=text('true')
|
||||
)
|
||||
created_at = Column(DateTime, nullable=False, server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime, nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
def verified_model(result: StoredVerifiedModel) -> VerifiedModel:
|
||||
return VerifiedModel(
|
||||
id=result.id,
|
||||
model_name=result.model_name,
|
||||
provider=result.provider,
|
||||
is_enabled=result.is_enabled,
|
||||
created_at=result.created_at,
|
||||
updated_at=result.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VerifiedModelService:
|
||||
"""Store for CRUD operations on verified models.
|
||||
|
||||
Follows the async pattern with db_session as an attribute.
|
||||
"""
|
||||
|
||||
db_session: AsyncSession
|
||||
|
||||
async def search_verified_models(
|
||||
self,
|
||||
provider: str | None = None,
|
||||
enabled_only: bool = True,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> VerifiedModelPage:
|
||||
"""Search for verified models with optional filtering and pagination.
|
||||
|
||||
Args:
|
||||
provider: Optional provider name to filter by (e.g., 'openhands', 'anthropic')
|
||||
enabled_only: If True, only return enabled models (default: True)
|
||||
page_id: Page id for pagination
|
||||
limit: Maximum number of records to return
|
||||
|
||||
Returns:
|
||||
SearchModelsResult containing items list and has_more flag
|
||||
"""
|
||||
query = select(StoredVerifiedModel)
|
||||
|
||||
# Build filters
|
||||
filters = []
|
||||
if provider:
|
||||
filters.append(StoredVerifiedModel.provider == provider)
|
||||
if enabled_only:
|
||||
filters.append(StoredVerifiedModel.is_enabled.is_(True))
|
||||
|
||||
if filters:
|
||||
query = query.where(and_(*filters))
|
||||
|
||||
# Order by provider, then model_name
|
||||
query = query.order_by(
|
||||
StoredVerifiedModel.provider, StoredVerifiedModel.model_name
|
||||
)
|
||||
|
||||
# Fetch limit + 1 to check if there are more results
|
||||
offset = int(page_id or '0')
|
||||
query = query.offset(offset).limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
results = list(result.scalars().all())
|
||||
has_more = len(results) > limit
|
||||
next_page_id = None
|
||||
|
||||
# Return only the requested number of results
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
results.pop()
|
||||
|
||||
items = [verified_model(result) for result in results]
|
||||
return VerifiedModelPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
async def get_model(self, model_name: str, provider: str) -> VerifiedModel | None:
|
||||
"""Get a model by its composite key (model_name, provider).
|
||||
|
||||
Args:
|
||||
model_name: The model identifier
|
||||
provider: The provider name
|
||||
"""
|
||||
query = select(StoredVerifiedModel).where(
|
||||
and_(
|
||||
StoredVerifiedModel.model_name == model_name,
|
||||
StoredVerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
return result.scalars().first()
|
||||
|
||||
async def create_verified_model(
|
||||
self,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
is_enabled: bool = True,
|
||||
) -> VerifiedModel:
|
||||
"""Create a new verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model identifier
|
||||
provider: The provider name
|
||||
is_enabled: Whether the model is enabled (default True)
|
||||
|
||||
Raises:
|
||||
ValueError: If a model with the same (model_name, provider) already exists
|
||||
"""
|
||||
existing_query = select(StoredVerifiedModel).where(
|
||||
and_(
|
||||
StoredVerifiedModel.model_name == model_name,
|
||||
StoredVerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
result = await self.db_session.execute(existing_query)
|
||||
existing = result.scalars().first()
|
||||
if existing:
|
||||
raise ValueError(f'Model {provider}/{model_name} already exists')
|
||||
|
||||
model = StoredVerifiedModel(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
is_enabled=is_enabled,
|
||||
)
|
||||
self.db_session.add(model)
|
||||
await self.db_session.commit()
|
||||
await self.db_session.refresh(model)
|
||||
logger.info(f'Created verified model: {provider}/{model_name}')
|
||||
return verified_model(model)
|
||||
|
||||
async def update_verified_model(
|
||||
self,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
is_enabled: bool | None = None,
|
||||
) -> VerifiedModel | None:
|
||||
"""Update an existing verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model name to update
|
||||
provider: The provider name
|
||||
is_enabled: New enabled state (optional)
|
||||
|
||||
Returns:
|
||||
The updated model if found, None otherwise
|
||||
"""
|
||||
query = select(StoredVerifiedModel).where(
|
||||
and_(
|
||||
StoredVerifiedModel.model_name == model_name,
|
||||
StoredVerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
model = result.scalars().first()
|
||||
if not model:
|
||||
return None
|
||||
|
||||
if is_enabled is not None:
|
||||
model.is_enabled = is_enabled
|
||||
|
||||
await self.db_session.commit()
|
||||
await self.db_session.refresh(model)
|
||||
logger.info(f'Updated verified model: {provider}/{model_name}')
|
||||
return verified_model(model)
|
||||
|
||||
async def delete_verified_model(self, model_name: str, provider: str):
|
||||
"""Delete a verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model name to delete
|
||||
provider: The provider name
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
query = select(StoredVerifiedModel).where(
|
||||
and_(
|
||||
StoredVerifiedModel.model_name == model_name,
|
||||
StoredVerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
model = result.scalars().first()
|
||||
if not model:
|
||||
raise ValueError('Unknown model')
|
||||
|
||||
await self.db_session.delete(model)
|
||||
await self.db_session.commit()
|
||||
logger.info(f'Deleted verified model: {provider}/{model_name}')
|
||||
|
||||
|
||||
def verified_model_store_dependency(db_session: AsyncSession = depends_db_session()):
|
||||
return VerifiedModelService(db_session)
|
||||
@@ -5,20 +5,16 @@ import string
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import select, update
|
||||
from storage.api_key import ApiKey
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiKeyStore:
|
||||
session_maker: sessionmaker
|
||||
|
||||
API_KEY_PREFIX = 'sk-oh-'
|
||||
|
||||
def generate_api_key(self, length: int = 32) -> str:
|
||||
@@ -43,22 +39,8 @@ class ApiKeyStore:
|
||||
api_key = self.generate_api_key()
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id
|
||||
await call_sync_from_async(
|
||||
self._store_api_key, user_id, org_id, api_key, name, expires_at
|
||||
)
|
||||
|
||||
return api_key
|
||||
|
||||
def _store_api_key(
|
||||
self,
|
||||
user_id: str,
|
||||
org_id: str,
|
||||
api_key: str,
|
||||
name: str | None,
|
||||
expires_at: datetime | None = None,
|
||||
) -> None:
|
||||
"""Store an existing API key in the database."""
|
||||
with self.session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key,
|
||||
user_id=user_id,
|
||||
@@ -67,14 +49,17 @@ class ApiKeyStore:
|
||||
expires_at=expires_at,
|
||||
)
|
||||
session.add(key_record)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
def validate_api_key(self, api_key: str) -> str | None:
|
||||
return api_key
|
||||
|
||||
async def validate_api_key(self, api_key: str) -> str | None:
|
||||
"""Validate an API key and return the associated user_id if valid."""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
with self.session_maker() as session:
|
||||
key_record = session.query(ApiKey).filter(ApiKey.key == api_key).first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key))
|
||||
key_record = result.scalars().first()
|
||||
|
||||
if not key_record:
|
||||
return None
|
||||
@@ -91,38 +76,40 @@ class ApiKeyStore:
|
||||
return None
|
||||
|
||||
# Update last_used_at timestamp
|
||||
session.execute(
|
||||
await session.execute(
|
||||
update(ApiKey)
|
||||
.where(ApiKey.id == key_record.id)
|
||||
.values(last_used_at=now)
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
return key_record.user_id
|
||||
|
||||
def delete_api_key(self, api_key: str) -> bool:
|
||||
async def delete_api_key(self, api_key: str) -> bool:
|
||||
"""Delete an API key by the key value."""
|
||||
with self.session_maker() as session:
|
||||
key_record = session.query(ApiKey).filter(ApiKey.key == api_key).first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key))
|
||||
key_record = result.scalars().first()
|
||||
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
session.delete(key_record)
|
||||
session.commit()
|
||||
await session.delete(key_record)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
def delete_api_key_by_id(self, key_id: int) -> bool:
|
||||
async def delete_api_key_by_id(self, key_id: int) -> bool:
|
||||
"""Delete an API key by its ID."""
|
||||
with self.session_maker() as session:
|
||||
key_record = session.query(ApiKey).filter(ApiKey.id == key_id).first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
key_record = result.scalars().first()
|
||||
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
session.delete(key_record)
|
||||
session.commit()
|
||||
await session.delete(key_record)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
@@ -130,64 +117,55 @@ class ApiKeyStore:
|
||||
"""List all API keys for a user."""
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id
|
||||
return await call_sync_from_async(self._list_api_keys_from_db, user_id, org_id)
|
||||
|
||||
def _list_api_keys_from_db(self, user_id: str, org_id: str) -> list[ApiKey]:
|
||||
with self.session_maker() as session:
|
||||
keys: list[ApiKey] = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(
|
||||
ApiKey.user_id == user_id, ApiKey.org_id == org_id
|
||||
)
|
||||
)
|
||||
|
||||
keys = result.scalars().all()
|
||||
return [key for key in keys if key.name != 'MCP_API_KEY']
|
||||
|
||||
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id
|
||||
return await call_sync_from_async(
|
||||
self._retrieve_mcp_api_key_from_db, user_id, org_id
|
||||
)
|
||||
|
||||
def _retrieve_mcp_api_key_from_db(self, user_id: str, org_id: str) -> str | None:
|
||||
with self.session_maker() as session:
|
||||
keys: list[ApiKey] = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(
|
||||
ApiKey.user_id == user_id, ApiKey.org_id == org_id
|
||||
)
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
for key in keys:
|
||||
if key.name == 'MCP_API_KEY':
|
||||
return key.key
|
||||
|
||||
return None
|
||||
|
||||
def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None:
|
||||
async def retrieve_api_key_by_name(self, user_id: str, name: str) -> str | None:
|
||||
"""Retrieve an API key by name for a specific user."""
|
||||
with self.session_maker() as session:
|
||||
key_record = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
)
|
||||
key_record = result.scalars().first()
|
||||
return key_record.key if key_record else None
|
||||
|
||||
def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
|
||||
async def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
|
||||
"""Delete an API key by name for a specific user."""
|
||||
with self.session_maker() as session:
|
||||
key_record = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
)
|
||||
key_record = result.scalars().first()
|
||||
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
session.delete(key_record)
|
||||
session.commit()
|
||||
await session.delete(key_record)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
@@ -195,4 +173,4 @@ class ApiKeyStore:
|
||||
def get_instance(cls) -> ApiKeyStore:
|
||||
"""Get an instance of the ApiKeyStore."""
|
||||
logger.debug('api_key_store.get_instance')
|
||||
return ApiKeyStore(session_maker)
|
||||
return ApiKeyStore()
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Awaitable, Callable, Dict
|
||||
from server.auth.auth_error import TokenRefreshError
|
||||
from sqlalchemy import select, text, update
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.database import a_session_maker
|
||||
|
||||
@@ -27,7 +26,6 @@ LOCK_TIMEOUT_SECONDS = 5
|
||||
class AuthTokenStore:
|
||||
keycloak_user_id: str
|
||||
idp: ProviderType
|
||||
a_session_maker: sessionmaker
|
||||
|
||||
@property
|
||||
def identity_provider_value(self) -> str:
|
||||
@@ -73,7 +71,7 @@ class AuthTokenStore:
|
||||
access_token_expires_at: Expiration time for access token (seconds since epoch)
|
||||
refresh_token_expires_at: Expiration time for refresh token (seconds since epoch)
|
||||
"""
|
||||
async with self.a_session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
async with session.begin(): # Explicitly start a transaction
|
||||
result = await session.execute(
|
||||
select(AuthTokens).where(
|
||||
@@ -138,7 +136,7 @@ class AuthTokenStore:
|
||||
a 401 response to prompt the user to re-authenticate.
|
||||
"""
|
||||
# FAST PATH: Check without lock first to avoid unnecessary lock contention
|
||||
async with self.a_session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(AuthTokens).filter(
|
||||
AuthTokens.keycloak_user_id == self.keycloak_user_id,
|
||||
@@ -167,7 +165,7 @@ class AuthTokenStore:
|
||||
|
||||
# SLOW PATH: Token needs refresh, acquire lock
|
||||
try:
|
||||
async with self.a_session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
async with session.begin():
|
||||
# Set a lock timeout to prevent indefinite blocking
|
||||
# This ensures we don't hold connections forever if something goes wrong
|
||||
@@ -300,6 +298,4 @@ class AuthTokenStore:
|
||||
logger.debug(f'auth_token_store.get_instance::{keycloak_user_id}')
|
||||
if keycloak_user_id:
|
||||
keycloak_user_id = str(keycloak_user_id)
|
||||
return AuthTokenStore(
|
||||
keycloak_user_id=keycloak_user_id, idp=idp, a_session_maker=a_session_maker
|
||||
)
|
||||
return AuthTokenStore(keycloak_user_id=keycloak_user_id, idp=idp)
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import a_session_maker
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockedEmailDomainStore:
|
||||
session_maker: sessionmaker
|
||||
|
||||
def is_domain_blocked(self, domain: str) -> bool:
|
||||
async def is_domain_blocked(self, domain: str) -> bool:
|
||||
"""Check if a domain is blocked by querying the database directly.
|
||||
|
||||
This method uses SQL to efficiently check if the domain matches any blocked pattern:
|
||||
@@ -21,9 +19,9 @@ class BlockedEmailDomainStore:
|
||||
Returns:
|
||||
True if the domain is blocked, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# SQL query that handles both TLD patterns and full domain patterns
|
||||
# TLD patterns (starting with '.'): check if domain ends with the pattern
|
||||
# TLD patterns (starting with '.'): check if domain ends with it (case-insensitive)
|
||||
# Full domain patterns: check for exact match or subdomain match
|
||||
# All comparisons are case-insensitive using LOWER() to ensure consistent matching
|
||||
query = text("""
|
||||
@@ -41,5 +39,5 @@ class BlockedEmailDomainStore:
|
||||
))
|
||||
)
|
||||
""")
|
||||
result = session.execute(query, {'domain': domain}).scalar()
|
||||
return bool(result)
|
||||
result = await session.execute(query, {'domain': domain})
|
||||
return bool(result.scalar())
|
||||
|
||||
@@ -47,7 +47,11 @@ class DeviceCode(Base):
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the device code has expired."""
|
||||
now = datetime.now(timezone.utc)
|
||||
return now > self.expires_at
|
||||
# Handle timezone-naive datetime from database by assuming it's UTC
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return now > expires_at
|
||||
|
||||
def is_pending(self) -> bool:
|
||||
"""Check if the device code is still pending authorization."""
|
||||
@@ -85,8 +89,13 @@ class DeviceCode(Base):
|
||||
if self.last_poll_time is None:
|
||||
return False, self.current_interval
|
||||
|
||||
# Handle timezone-naive datetime from database by assuming it's UTC
|
||||
last_poll_time = self.last_poll_time
|
||||
if last_poll_time.tzinfo is None:
|
||||
last_poll_time = last_poll_time.replace(tzinfo=timezone.utc)
|
||||
|
||||
# Calculate time since last poll
|
||||
time_since_last_poll = (now - self.last_poll_time).total_seconds()
|
||||
time_since_last_poll = (now - last_poll_time).total_seconds()
|
||||
|
||||
# Check if polling too fast
|
||||
if time_since_last_poll < self.current_interval:
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
"""Device code store for OAuth 2.0 Device Flow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from storage.database import a_session_maker
|
||||
from storage.device_code import DeviceCode
|
||||
|
||||
|
||||
class DeviceCodeStore:
|
||||
"""Store for managing OAuth 2.0 device codes."""
|
||||
|
||||
def __init__(self, session_maker):
|
||||
self.session_maker = session_maker
|
||||
|
||||
def generate_user_code(self) -> str:
|
||||
"""Generate a human-readable user code (8 characters, uppercase letters and digits)."""
|
||||
# Use a mix of uppercase letters and digits, avoiding confusing characters
|
||||
@@ -25,7 +26,7 @@ class DeviceCodeStore:
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
return ''.join(secrets.choice(alphabet) for _ in range(128))
|
||||
|
||||
def create_device_code(
|
||||
async def create_device_code(
|
||||
self,
|
||||
expires_in: int = 600, # 10 minutes default
|
||||
max_attempts: int = 10,
|
||||
@@ -58,11 +59,10 @@ class DeviceCodeStore:
|
||||
)
|
||||
|
||||
try:
|
||||
with self.session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(device_code_entry)
|
||||
session.commit()
|
||||
session.refresh(device_code_entry)
|
||||
session.expunge(device_code_entry) # Detach from session cleanly
|
||||
await session.commit()
|
||||
await session.refresh(device_code_entry)
|
||||
return device_code_entry
|
||||
except IntegrityError:
|
||||
# Constraint violation - codes already exist, retry with new codes
|
||||
@@ -72,25 +72,23 @@ class DeviceCodeStore:
|
||||
f'Failed to generate unique device codes after {max_attempts} attempts'
|
||||
)
|
||||
|
||||
def get_by_device_code(self, device_code: str) -> DeviceCode | None:
|
||||
async def get_by_device_code(self, device_code: str) -> DeviceCode | None:
|
||||
"""Get device code entry by device code."""
|
||||
with self.session_maker() as session:
|
||||
result = (
|
||||
session.query(DeviceCode).filter_by(device_code=device_code).first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(DeviceCode).filter_by(device_code=device_code)
|
||||
)
|
||||
if result:
|
||||
session.expunge(result) # Detach from session cleanly
|
||||
return result
|
||||
return result.scalars().first()
|
||||
|
||||
def get_by_user_code(self, user_code: str) -> DeviceCode | None:
|
||||
async def get_by_user_code(self, user_code: str) -> DeviceCode | None:
|
||||
"""Get device code entry by user code."""
|
||||
with self.session_maker() as session:
|
||||
result = session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
if result:
|
||||
session.expunge(result) # Detach from session cleanly
|
||||
return result
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(DeviceCode).filter_by(user_code=user_code)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
def authorize_device_code(self, user_code: str, user_id: str) -> bool:
|
||||
async def authorize_device_code(self, user_code: str, user_id: str) -> bool:
|
||||
"""Authorize a device code.
|
||||
|
||||
Args:
|
||||
@@ -100,10 +98,11 @@ class DeviceCodeStore:
|
||||
Returns:
|
||||
True if authorization was successful, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(DeviceCode).filter_by(user_code=user_code)
|
||||
)
|
||||
device_code_entry = result.scalars().first()
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
@@ -112,11 +111,11 @@ class DeviceCodeStore:
|
||||
return False
|
||||
|
||||
device_code_entry.authorize(user_id)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
def deny_device_code(self, user_code: str) -> bool:
|
||||
async def deny_device_code(self, user_code: str) -> bool:
|
||||
"""Deny a device code authorization.
|
||||
|
||||
Args:
|
||||
@@ -125,10 +124,11 @@ class DeviceCodeStore:
|
||||
Returns:
|
||||
True if denial was successful, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(user_code=user_code).first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(DeviceCode).filter_by(user_code=user_code)
|
||||
)
|
||||
device_code_entry = result.scalars().first()
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
@@ -137,11 +137,11 @@ class DeviceCodeStore:
|
||||
return False
|
||||
|
||||
device_code_entry.deny()
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
def update_poll_time(
|
||||
async def update_poll_time(
|
||||
self, device_code: str, increase_interval: bool = False
|
||||
) -> bool:
|
||||
"""Update the poll time for a device code and optionally increase interval.
|
||||
@@ -153,15 +153,16 @@ class DeviceCodeStore:
|
||||
Returns:
|
||||
True if update was successful, False otherwise
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
device_code_entry = (
|
||||
session.query(DeviceCode).filter_by(device_code=device_code).first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(DeviceCode).filter_by(device_code=device_code)
|
||||
)
|
||||
device_code_entry = result.scalars().first()
|
||||
|
||||
if not device_code_entry:
|
||||
return False
|
||||
|
||||
device_code_entry.update_poll_time(increase_interval)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
@@ -5,7 +5,6 @@ from dataclasses import dataclass
|
||||
from integrations.types import GitLabResourceType
|
||||
from sqlalchemy import and_, asc, select, text, update
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import a_session_maker
|
||||
from storage.gitlab_webhook import GitlabWebhook
|
||||
|
||||
@@ -14,8 +13,6 @@ from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@dataclass
|
||||
class GitlabWebhookStore:
|
||||
a_session_maker: sessionmaker = a_session_maker
|
||||
|
||||
@staticmethod
|
||||
def determine_resource_type(
|
||||
webhook: GitlabWebhook,
|
||||
@@ -44,7 +41,7 @@ class GitlabWebhookStore:
|
||||
if not project_details:
|
||||
return
|
||||
|
||||
async with self.a_session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
async with session.begin():
|
||||
# Convert GitlabWebhook objects to dictionaries for the insert
|
||||
# Using __dict__ and filtering out SQLAlchemy internal attributes and 'id'
|
||||
@@ -88,7 +85,7 @@ class GitlabWebhookStore:
|
||||
"""
|
||||
|
||||
resource_type, resource_id = GitlabWebhookStore.determine_resource_type(webhook)
|
||||
async with self.a_session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
async with session.begin():
|
||||
stmt = (
|
||||
update(GitlabWebhook).where(GitlabWebhook.project_id == resource_id)
|
||||
@@ -122,7 +119,7 @@ class GitlabWebhookStore:
|
||||
},
|
||||
)
|
||||
|
||||
async with self.a_session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
async with session.begin():
|
||||
# Create query based on the identifier provided
|
||||
if resource_type == GitLabResourceType.PROJECT:
|
||||
@@ -185,7 +182,7 @@ class GitlabWebhookStore:
|
||||
List of GitlabWebhook objects that need processing
|
||||
"""
|
||||
|
||||
async with self.a_session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
query = (
|
||||
select(GitlabWebhook)
|
||||
.where(GitlabWebhook.webhook_exists.is_(False))
|
||||
@@ -201,7 +198,7 @@ class GitlabWebhookStore:
|
||||
"""
|
||||
Get's webhook secret given the webhook uuid and admin keycloak user id
|
||||
"""
|
||||
async with self.a_session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
query = (
|
||||
select(GitlabWebhook)
|
||||
.where(
|
||||
@@ -235,7 +232,7 @@ class GitlabWebhookStore:
|
||||
Returns:
|
||||
GitlabWebhook object if found, None otherwise
|
||||
"""
|
||||
async with self.a_session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
if resource_type == GitLabResourceType.PROJECT:
|
||||
query = select(GitlabWebhook).where(
|
||||
GitlabWebhook.project_id == resource_id
|
||||
@@ -263,7 +260,7 @@ class GitlabWebhookStore:
|
||||
Returns:
|
||||
Tuple of (project_webhook_map, group_webhook_map)
|
||||
"""
|
||||
async with self.a_session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
project_webhook_map = {}
|
||||
group_webhook_map = {}
|
||||
|
||||
@@ -303,7 +300,7 @@ class GitlabWebhookStore:
|
||||
Returns:
|
||||
True if webhook was reset, False if not found
|
||||
"""
|
||||
async with self.a_session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
async with session.begin():
|
||||
if resource_type == GitLabResourceType.PROJECT:
|
||||
update_statement = (
|
||||
@@ -348,4 +345,4 @@ class GitlabWebhookStore:
|
||||
Returns:
|
||||
An instance of GitlabWebhookStore
|
||||
"""
|
||||
return GitlabWebhookStore(a_session_maker)
|
||||
return GitlabWebhookStore()
|
||||
|
||||
@@ -3,7 +3,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.jira_dc_conversation import JiraDcConversation
|
||||
from storage.jira_dc_user import JiraDcUser
|
||||
from storage.jira_dc_workspace import JiraDcWorkspace
|
||||
@@ -24,7 +25,7 @@ class JiraDcIntegrationStore:
|
||||
) -> JiraDcWorkspace:
|
||||
"""Create a new Jira DC workspace with encrypted sensitive data."""
|
||||
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
workspace = JiraDcWorkspace(
|
||||
name=name.lower(),
|
||||
admin_user_id=admin_user_id,
|
||||
@@ -34,8 +35,8 @@ class JiraDcIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
session.add(workspace)
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
logger.info(f'[Jira DC] Created workspace {workspace.name}')
|
||||
return workspace
|
||||
|
||||
@@ -48,11 +49,12 @@ class JiraDcIntegrationStore:
|
||||
status: Optional[str] = None,
|
||||
) -> JiraDcWorkspace:
|
||||
"""Update an existing Jira DC workspace with encrypted sensitive data."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Find existing workspace by ID
|
||||
workspace = (
|
||||
session.query(JiraDcWorkspace).filter(JiraDcWorkspace.id == id).first()
|
||||
result = await session.execute(
|
||||
select(JiraDcWorkspace).where(JiraDcWorkspace.id == id)
|
||||
)
|
||||
workspace = result.scalar_one_or_none()
|
||||
|
||||
if not workspace:
|
||||
raise ValueError(f'Workspace with ID "{id}" not found')
|
||||
@@ -69,8 +71,8 @@ class JiraDcIntegrationStore:
|
||||
if status is not None:
|
||||
workspace.status = status
|
||||
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
|
||||
logger.info(f'[Jira DC] Updated workspace {workspace.name}')
|
||||
return workspace
|
||||
@@ -91,10 +93,10 @@ class JiraDcIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(jira_dc_user)
|
||||
session.commit()
|
||||
session.refresh(jira_dc_user)
|
||||
await session.commit()
|
||||
await session.refresh(jira_dc_user)
|
||||
|
||||
logger.info(
|
||||
f'[Jira DC] Created user {jira_dc_user.id} for workspace {jira_dc_workspace_id}'
|
||||
@@ -103,94 +105,91 @@ class JiraDcIntegrationStore:
|
||||
|
||||
async def get_workspace_by_id(self, workspace_id: int) -> Optional[JiraDcWorkspace]:
|
||||
"""Retrieve workspace by ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcWorkspace)
|
||||
.filter(JiraDcWorkspace.id == workspace_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcWorkspace).where(JiraDcWorkspace.id == workspace_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_workspace_by_name(
|
||||
self, workspace_name: str
|
||||
) -> Optional[JiraDcWorkspace]:
|
||||
"""Retrieve workspace by name."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcWorkspace)
|
||||
.filter(JiraDcWorkspace.name == workspace_name.lower())
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcWorkspace).where(
|
||||
JiraDcWorkspace.name == workspace_name.lower()
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_user_by_active_workspace(
|
||||
self, keycloak_user_id: str
|
||||
) -> Optional[JiraDcUser]:
|
||||
"""Retrieve user by Keycloak user ID."""
|
||||
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
JiraDcUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraDcUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_user_by_keycloak_id_and_workspace(
|
||||
self, keycloak_user_id: str, jira_dc_workspace_id: int
|
||||
) -> Optional[JiraDcUser]:
|
||||
"""Get Jira DC user by Keycloak user ID and workspace ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
JiraDcUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_active_user(
|
||||
self, jira_dc_user_id: str, jira_dc_workspace_id: int
|
||||
) -> Optional[JiraDcUser]:
|
||||
"""Get Jira DC user by Keycloak user ID and workspace ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
JiraDcUser.jira_dc_user_id == jira_dc_user_id,
|
||||
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
|
||||
JiraDcUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_active_user_by_keycloak_id_and_workspace(
|
||||
self, keycloak_user_id: str, jira_dc_workspace_id: int
|
||||
) -> Optional[JiraDcUser]:
|
||||
"""Get Jira DC user by Keycloak user ID and workspace ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
JiraDcUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraDcUser.jira_dc_workspace_id == jira_dc_workspace_id,
|
||||
JiraDcUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_user_integration_status(
|
||||
self, keycloak_user_id: str, status: str
|
||||
) -> JiraDcUser:
|
||||
"""Update the status of a Jira DC user mapping."""
|
||||
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(JiraDcUser)
|
||||
.filter(JiraDcUser.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
JiraDcUser.keycloak_user_id == keycloak_user_id
|
||||
)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
raise ValueError(
|
||||
@@ -198,37 +197,35 @@ class JiraDcIntegrationStore:
|
||||
)
|
||||
|
||||
user.status = status
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
logger.info(f'[Jira DC] Updated user {keycloak_user_id} status to {status}')
|
||||
return user
|
||||
|
||||
async def deactivate_workspace(self, workspace_id: int):
|
||||
"""Deactivate the workspace and all user links for a given workspace."""
|
||||
with session_maker() as session:
|
||||
users = (
|
||||
session.query(JiraDcUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcUser).where(
|
||||
JiraDcUser.jira_dc_workspace_id == workspace_id,
|
||||
JiraDcUser.status == 'active',
|
||||
)
|
||||
.all()
|
||||
)
|
||||
users = result.scalars().all()
|
||||
|
||||
for user in users:
|
||||
user.status = 'inactive'
|
||||
session.add(user)
|
||||
|
||||
workspace = (
|
||||
session.query(JiraDcWorkspace)
|
||||
.filter(JiraDcWorkspace.id == workspace_id)
|
||||
.first()
|
||||
result = await session.execute(
|
||||
select(JiraDcWorkspace).where(JiraDcWorkspace.id == workspace_id)
|
||||
)
|
||||
workspace = result.scalar_one_or_none()
|
||||
if workspace:
|
||||
workspace.status = 'inactive'
|
||||
session.add(workspace)
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
f'[Jira DC] Deactivated all user links for workspace {workspace_id}'
|
||||
@@ -238,23 +235,22 @@ class JiraDcIntegrationStore:
|
||||
self, jira_dc_conversation: JiraDcConversation
|
||||
) -> None:
|
||||
"""Create a new Jira DC conversation record."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(jira_dc_conversation)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
async def get_user_conversations_by_issue_id(
|
||||
self, issue_id: str, jira_dc_user_id: int
|
||||
) -> JiraDcConversation | None:
|
||||
"""Get a Jira DC conversation by issue ID and jira dc user ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcConversation)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraDcConversation).where(
|
||||
JiraDcConversation.issue_id == issue_id,
|
||||
JiraDcConversation.jira_dc_user_id == jira_dc_user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> JiraDcIntegrationStore:
|
||||
|
||||
@@ -3,7 +3,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import and_, select
|
||||
from storage.database import a_session_maker
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
@@ -35,10 +36,10 @@ class JiraIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(workspace)
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
|
||||
logger.info(f'[Jira] Created workspace {workspace.name}')
|
||||
return workspace
|
||||
@@ -53,11 +54,12 @@ class JiraIntegrationStore:
|
||||
status: Optional[str] = None,
|
||||
) -> JiraWorkspace:
|
||||
"""Update an existing Jira workspace with encrypted sensitive data."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Find existing workspace by ID
|
||||
workspace = (
|
||||
session.query(JiraWorkspace).filter(JiraWorkspace.id == id).first()
|
||||
result = await session.execute(
|
||||
select(JiraWorkspace).filter(JiraWorkspace.id == id)
|
||||
)
|
||||
workspace = result.scalars().first()
|
||||
|
||||
if not workspace:
|
||||
raise ValueError(f'Workspace with ID "{id}" not found')
|
||||
@@ -77,11 +79,11 @@ class JiraIntegrationStore:
|
||||
if status is not None:
|
||||
workspace.status = status
|
||||
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
|
||||
logger.info(f'[Jira] Updated workspace {workspace.name}')
|
||||
return workspace
|
||||
logger.info(f'[Jira] Updated workspace {workspace.name}')
|
||||
return workspace
|
||||
|
||||
async def create_workspace_link(
|
||||
self,
|
||||
@@ -99,10 +101,10 @@ class JiraIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(jira_user)
|
||||
session.commit()
|
||||
session.refresh(jira_user)
|
||||
await session.commit()
|
||||
await session.refresh(jira_user)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Created user {jira_user.id} for workspace {jira_workspace_id}'
|
||||
@@ -111,75 +113,77 @@ class JiraIntegrationStore:
|
||||
|
||||
async def get_workspace_by_id(self, workspace_id: int) -> Optional[JiraWorkspace]:
|
||||
"""Retrieve workspace by ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraWorkspace)
|
||||
.filter(JiraWorkspace.id == workspace_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraWorkspace).filter(JiraWorkspace.id == workspace_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_workspace_by_name(self, workspace_name: str) -> JiraWorkspace | None:
|
||||
"""Retrieve workspace by name."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraWorkspace)
|
||||
.filter(JiraWorkspace.name == workspace_name.lower())
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraWorkspace).filter(
|
||||
JiraWorkspace.name == workspace_name.lower()
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_user_by_active_workspace(
|
||||
self, keycloak_user_id: str
|
||||
) -> Optional[JiraUser]:
|
||||
"""Get Jira user by Keycloak user ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraUser)
|
||||
.filter(
|
||||
JiraUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraUser.status == 'active',
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraUser).filter(
|
||||
and_(
|
||||
JiraUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraUser.status == 'active',
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_user_by_keycloak_id_and_workspace(
|
||||
self, keycloak_user_id: str, jira_workspace_id: int
|
||||
) -> Optional[JiraUser]:
|
||||
"""Get Jira user by Keycloak user ID and workspace ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraUser)
|
||||
.filter(
|
||||
JiraUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraUser.jira_workspace_id == jira_workspace_id,
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraUser).filter(
|
||||
and_(
|
||||
JiraUser.keycloak_user_id == keycloak_user_id,
|
||||
JiraUser.jira_workspace_id == jira_workspace_id,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_active_user(
|
||||
self, jira_user_id: str, jira_workspace_id: int
|
||||
) -> Optional[JiraUser]:
|
||||
"""Get Jira user by Keycloak user ID and workspace ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraUser)
|
||||
.filter(
|
||||
JiraUser.jira_user_id == jira_user_id,
|
||||
JiraUser.jira_workspace_id == jira_workspace_id,
|
||||
JiraUser.status == 'active',
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraUser).filter(
|
||||
and_(
|
||||
JiraUser.jira_user_id == jira_user_id,
|
||||
JiraUser.jira_workspace_id == jira_workspace_id,
|
||||
JiraUser.status == 'active',
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def update_user_integration_status(
|
||||
self, keycloak_user_id: str, status: str
|
||||
) -> JiraUser:
|
||||
"""Update Jira user integration status."""
|
||||
with session_maker() as session:
|
||||
jira_user = (
|
||||
session.query(JiraUser)
|
||||
.filter(JiraUser.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraUser).filter(JiraUser.keycloak_user_id == keycloak_user_id)
|
||||
)
|
||||
jira_user = result.scalars().first()
|
||||
|
||||
if not jira_user:
|
||||
raise ValueError(
|
||||
@@ -187,60 +191,61 @@ class JiraIntegrationStore:
|
||||
)
|
||||
|
||||
jira_user.status = status
|
||||
session.commit()
|
||||
session.refresh(jira_user)
|
||||
await session.commit()
|
||||
await session.refresh(jira_user)
|
||||
|
||||
logger.info(f'[Jira] Updated user {keycloak_user_id} status to {status}')
|
||||
return jira_user
|
||||
|
||||
async def deactivate_workspace(self, workspace_id: int):
|
||||
"""Deactivate the workspace and all user links for a given workspace."""
|
||||
with session_maker() as session:
|
||||
users = (
|
||||
session.query(JiraUser)
|
||||
.filter(
|
||||
JiraUser.jira_workspace_id == workspace_id,
|
||||
JiraUser.status == 'active',
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraUser).filter(
|
||||
and_(
|
||||
JiraUser.jira_workspace_id == workspace_id,
|
||||
JiraUser.status == 'active',
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
users = result.scalars().all()
|
||||
|
||||
for user in users:
|
||||
user.status = 'inactive'
|
||||
session.add(user)
|
||||
|
||||
workspace = (
|
||||
session.query(JiraWorkspace)
|
||||
.filter(JiraWorkspace.id == workspace_id)
|
||||
.first()
|
||||
result = await session.execute(
|
||||
select(JiraWorkspace).filter(JiraWorkspace.id == workspace_id)
|
||||
)
|
||||
workspace = result.scalars().first()
|
||||
if workspace:
|
||||
workspace.status = 'inactive'
|
||||
session.add(workspace)
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(f'[Jira] Deactivated all user links for workspace {workspace_id}')
|
||||
|
||||
async def create_conversation(self, jira_conversation: JiraConversation) -> None:
|
||||
"""Create a new Jira conversation record."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(jira_conversation)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
async def get_user_conversations_by_issue_id(
|
||||
self, issue_id: str, jira_user_id: int
|
||||
) -> JiraConversation | None:
|
||||
"""Get a Jira conversation by issue ID and jira user ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraConversation)
|
||||
.filter(
|
||||
JiraConversation.issue_id == issue_id,
|
||||
JiraConversation.jira_user_id == jira_user_id,
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(JiraConversation).filter(
|
||||
and_(
|
||||
JiraConversation.issue_id == issue_id,
|
||||
JiraConversation.jira_user_id == jira_user_id,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> JiraIntegrationStore:
|
||||
|
||||
@@ -3,7 +3,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.linear_conversation import LinearConversation
|
||||
from storage.linear_user import LinearUser
|
||||
from storage.linear_workspace import LinearWorkspace
|
||||
@@ -35,10 +36,10 @@ class LinearIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(workspace)
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
|
||||
logger.info(f'[Linear] Created workspace {workspace.name}')
|
||||
return workspace
|
||||
@@ -53,11 +54,12 @@ class LinearIntegrationStore:
|
||||
status: Optional[str] = None,
|
||||
) -> LinearWorkspace:
|
||||
"""Update an existing Linear workspace with encrypted sensitive data."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Find existing workspace by ID
|
||||
workspace = (
|
||||
session.query(LinearWorkspace).filter(LinearWorkspace.id == id).first()
|
||||
result = await session.execute(
|
||||
select(LinearWorkspace).where(LinearWorkspace.id == id)
|
||||
)
|
||||
workspace = result.scalar_one_or_none()
|
||||
|
||||
if not workspace:
|
||||
raise ValueError(f'Workspace with ID "{id}" not found')
|
||||
@@ -77,8 +79,8 @@ class LinearIntegrationStore:
|
||||
if status is not None:
|
||||
workspace.status = status
|
||||
|
||||
session.commit()
|
||||
session.refresh(workspace)
|
||||
await session.commit()
|
||||
await session.refresh(workspace)
|
||||
|
||||
logger.info(f'[Linear] Updated workspace {workspace.name}')
|
||||
return workspace
|
||||
@@ -98,10 +100,10 @@ class LinearIntegrationStore:
|
||||
status=status,
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(linear_user)
|
||||
session.commit()
|
||||
session.refresh(linear_user)
|
||||
await session.commit()
|
||||
await session.refresh(linear_user)
|
||||
|
||||
logger.info(
|
||||
f'[Linear] Created user {linear_user.id} for workspace {linear_workspace_id}'
|
||||
@@ -110,77 +112,75 @@ class LinearIntegrationStore:
|
||||
|
||||
async def get_workspace_by_id(self, workspace_id: int) -> Optional[LinearWorkspace]:
|
||||
"""Retrieve workspace by ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearWorkspace)
|
||||
.filter(LinearWorkspace.id == workspace_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearWorkspace).where(LinearWorkspace.id == workspace_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_workspace_by_name(
|
||||
self, workspace_name: str
|
||||
) -> Optional[LinearWorkspace]:
|
||||
"""Retrieve workspace by name."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearWorkspace)
|
||||
.filter(LinearWorkspace.name == workspace_name.lower())
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearWorkspace).where(
|
||||
LinearWorkspace.name == workspace_name.lower()
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_user_by_active_workspace(
|
||||
self, keycloak_user_id: str
|
||||
) -> LinearUser | None:
|
||||
"""Get Linear user by Keycloak user ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
LinearUser.keycloak_user_id == keycloak_user_id,
|
||||
LinearUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_user_by_keycloak_id_and_workspace(
|
||||
self, keycloak_user_id: str, linear_workspace_id: int
|
||||
) -> Optional[LinearUser]:
|
||||
"""Get Linear user by Keycloak user ID and workspace ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
LinearUser.keycloak_user_id == keycloak_user_id,
|
||||
LinearUser.linear_workspace_id == linear_workspace_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_active_user(
|
||||
self, linear_user_id: str, linear_workspace_id: int
|
||||
) -> Optional[LinearUser]:
|
||||
"""Get Linear user by Keycloak user ID and workspace ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
LinearUser.linear_user_id == linear_user_id,
|
||||
LinearUser.linear_workspace_id == linear_workspace_id,
|
||||
LinearUser.status == 'active',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_user_integration_status(
|
||||
self, keycloak_user_id: str, status: str
|
||||
) -> LinearUser:
|
||||
"""Update Linear user integration status."""
|
||||
with session_maker() as session:
|
||||
linear_user = (
|
||||
session.query(LinearUser)
|
||||
.filter(LinearUser.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
LinearUser.keycloak_user_id == keycloak_user_id
|
||||
)
|
||||
)
|
||||
linear_user = result.scalar_one_or_none()
|
||||
|
||||
if not linear_user:
|
||||
raise ValueError(
|
||||
@@ -188,38 +188,36 @@ class LinearIntegrationStore:
|
||||
)
|
||||
|
||||
linear_user.status = status
|
||||
session.commit()
|
||||
session.refresh(linear_user)
|
||||
await session.commit()
|
||||
await session.refresh(linear_user)
|
||||
|
||||
logger.info(f'[Linear] Updated user {keycloak_user_id} status to {status}')
|
||||
return linear_user
|
||||
|
||||
async def deactivate_workspace(self, workspace_id: int):
|
||||
"""Deactivate the workspace and all user links for a given workspace."""
|
||||
with session_maker() as session:
|
||||
users = (
|
||||
session.query(LinearUser)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearUser).where(
|
||||
LinearUser.linear_workspace_id == workspace_id,
|
||||
LinearUser.status == 'active',
|
||||
)
|
||||
.all()
|
||||
)
|
||||
users = result.scalars().all()
|
||||
|
||||
for user in users:
|
||||
user.status = 'inactive'
|
||||
session.add(user)
|
||||
|
||||
workspace = (
|
||||
session.query(LinearWorkspace)
|
||||
.filter(LinearWorkspace.id == workspace_id)
|
||||
.first()
|
||||
result = await session.execute(
|
||||
select(LinearWorkspace).where(LinearWorkspace.id == workspace_id)
|
||||
)
|
||||
workspace = result.scalar_one_or_none()
|
||||
if workspace:
|
||||
workspace.status = 'inactive'
|
||||
session.add(workspace)
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(f'[Jira] Deactivated all user links for workspace {workspace_id}')
|
||||
|
||||
@@ -227,23 +225,22 @@ class LinearIntegrationStore:
|
||||
self, linear_conversation: LinearConversation
|
||||
) -> None:
|
||||
"""Create a new Linear conversation record."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(linear_conversation)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
async def get_user_conversations_by_issue_id(
|
||||
self, issue_id: str, linear_user_id: int
|
||||
) -> LinearConversation | None:
|
||||
"""Get a Linear conversation by issue ID and linear user ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(LinearConversation)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(LinearConversation).where(
|
||||
LinearConversation.issue_id == issue_id,
|
||||
LinearConversation.linear_user_id == linear_user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> LinearIntegrationStore:
|
||||
|
||||
@@ -2,8 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
@@ -13,17 +13,17 @@ from openhands.core.logger import openhands_logger as logger
|
||||
@dataclass
|
||||
class OfflineTokenStore:
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
|
||||
async def store_token(self, offline_token: str) -> None:
|
||||
"""Store an offline token in the database."""
|
||||
with self.session_maker() as session:
|
||||
token_record = (
|
||||
session.query(StoredOfflineToken)
|
||||
.filter(StoredOfflineToken.user_id == self.user_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StoredOfflineToken).where(
|
||||
StoredOfflineToken.user_id == self.user_id
|
||||
)
|
||||
)
|
||||
token_record = result.scalar_one_or_none()
|
||||
|
||||
if token_record:
|
||||
token_record.offline_token = offline_token
|
||||
@@ -32,16 +32,17 @@ class OfflineTokenStore:
|
||||
user_id=self.user_id, offline_token=offline_token
|
||||
)
|
||||
session.add(token_record)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
async def load_token(self) -> str | None:
|
||||
"""Load an offline token from the database."""
|
||||
with self.session_maker() as session:
|
||||
token_record = (
|
||||
session.query(StoredOfflineToken)
|
||||
.filter(StoredOfflineToken.user_id == self.user_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StoredOfflineToken).where(
|
||||
StoredOfflineToken.user_id == self.user_id
|
||||
)
|
||||
)
|
||||
token_record = result.scalar_one_or_none()
|
||||
|
||||
if not token_record:
|
||||
return None
|
||||
@@ -56,4 +57,4 @@ class OfflineTokenStore:
|
||||
logger.debug(f'offline_token_store.get_instance::{user_id}')
|
||||
if user_id:
|
||||
user_id = str(user_id)
|
||||
return OfflineTokenStore(user_id, session_maker, config)
|
||||
return OfflineTokenStore(user_id, config)
|
||||
|
||||
105
enterprise/storage/org_app_settings_store.py
Normal file
105
enterprise/storage/org_app_settings_store.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Store class for managing organization app settings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from server.constants import (
|
||||
LITE_LLM_API_URL,
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.routes.org_models import OrgAppSettingsUpdate
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrgAppSettingsStore:
|
||||
"""Store for organization app settings with injected db_session."""
|
||||
|
||||
db_session: AsyncSession
|
||||
|
||||
async def get_current_org_by_user_id(self, user_id: str) -> Org | None:
|
||||
"""Get the current organization for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID (Keycloak user ID)
|
||||
|
||||
Returns:
|
||||
Org: The organization object, or None if not found
|
||||
"""
|
||||
# Get user with their current_org_id
|
||||
result = await self.db_session.execute(
|
||||
select(User).filter(User.id == UUID(user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
||||
org_id = user.current_org_id
|
||||
if not org_id:
|
||||
return None
|
||||
|
||||
# Get the organization
|
||||
result = await self.db_session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
|
||||
if not org:
|
||||
return None
|
||||
|
||||
return await self._validate_org_version(org)
|
||||
|
||||
async def _validate_org_version(self, org: Org) -> Org:
|
||||
"""Check if we need to update org version.
|
||||
|
||||
Args:
|
||||
org: The organization to validate
|
||||
|
||||
Returns:
|
||||
Org: The validated (and potentially updated) organization
|
||||
"""
|
||||
if org.org_version < ORG_SETTINGS_VERSION:
|
||||
org.org_version = ORG_SETTINGS_VERSION
|
||||
org.default_llm_model = get_default_litellm_model()
|
||||
org.llm_base_url = LITE_LLM_API_URL
|
||||
await self.db_session.flush()
|
||||
await self.db_session.refresh(org)
|
||||
|
||||
return org
|
||||
|
||||
async def update_org_app_settings(
|
||||
self, org_id: UUID, update_data: OrgAppSettingsUpdate
|
||||
) -> Org | None:
|
||||
"""Update organization app settings.
|
||||
|
||||
Only updates fields that are explicitly provided in update_data.
|
||||
Uses flush() - commit happens at request end via DbSessionInjector.
|
||||
|
||||
Args:
|
||||
org_id: The organization's ID
|
||||
update_data: Pydantic model with fields to update
|
||||
|
||||
Returns:
|
||||
Org: The updated organization object, or None if not found
|
||||
"""
|
||||
result = await self.db_session.execute(
|
||||
select(Org).filter(Org.id == org_id).with_for_update()
|
||||
)
|
||||
org = result.scalars().first()
|
||||
|
||||
if not org:
|
||||
return None
|
||||
|
||||
# Update only explicitly provided fields
|
||||
for field, value in update_data.model_dump(exclude_unset=True).items():
|
||||
setattr(org, field, value)
|
||||
|
||||
# flush instead of commit - DbSessionInjector auto-commits at request end
|
||||
await self.db_session.flush()
|
||||
await self.db_session.refresh(org)
|
||||
return org
|
||||
83
enterprise/storage/org_llm_settings_store.py
Normal file
83
enterprise/storage/org_llm_settings_store.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Store class for managing organization LLM settings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from server.routes.org_models import OrgLLMSettingsUpdate
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.org import Org
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrgLLMSettingsStore:
|
||||
"""Store for org LLM settings with injected db_session."""
|
||||
|
||||
db_session: AsyncSession
|
||||
|
||||
async def get_current_org_by_user_id(self, user_id: str) -> Org | None:
|
||||
"""Get the user's current organization.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID (Keycloak user ID)
|
||||
|
||||
Returns:
|
||||
Org: The user's current organization, or None if not found
|
||||
"""
|
||||
# First get the user to find their current_org_id
|
||||
result = await self.db_session.execute(
|
||||
select(User).filter(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
|
||||
if not user or not user.current_org_id:
|
||||
return None
|
||||
|
||||
# Then get the org
|
||||
result = await self.db_session.execute(
|
||||
select(Org).filter(Org.id == user.current_org_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def update_org_llm_settings(
|
||||
self, org_id: UUID, update_data: OrgLLMSettingsUpdate
|
||||
) -> Org | None:
|
||||
"""Update organization LLM settings.
|
||||
|
||||
Also propagates relevant settings to all org members.
|
||||
Uses flush() - commit happens at request end via DbSessionInjector.
|
||||
|
||||
Args:
|
||||
org_id: The organization's ID
|
||||
update_data: Pydantic model with fields to update
|
||||
|
||||
Returns:
|
||||
Org: The updated organization, or None if org not found
|
||||
"""
|
||||
result = await self.db_session.execute(
|
||||
select(Org).filter(Org.id == org_id).with_for_update()
|
||||
)
|
||||
org = result.scalars().first()
|
||||
|
||||
if not org:
|
||||
return None
|
||||
|
||||
# Apply updates to org (excludes llm_api_key which is member-only)
|
||||
update_data.apply_to_org(org)
|
||||
|
||||
# Propagate relevant settings to all org members
|
||||
member_updates = update_data.get_member_updates()
|
||||
if member_updates:
|
||||
await OrgMemberStore.update_all_members_llm_settings_async(
|
||||
self.db_session, org_id, member_updates
|
||||
)
|
||||
|
||||
# flush instead of commit - DbSessionInjector auto-commits at request end
|
||||
await self.db_session.flush()
|
||||
await self.db_session.refresh(org)
|
||||
return org
|
||||
@@ -5,9 +5,12 @@ Store class for managing organization-member relationships.
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from server.routes.org_models import OrgMemberLLMSettings
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.encrypt_utils import encrypt_value
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
@@ -254,3 +257,28 @@ class OrgMemberStore:
|
||||
members = members[:limit]
|
||||
|
||||
return members, has_more
|
||||
|
||||
@staticmethod
|
||||
async def update_all_members_llm_settings_async(
|
||||
session: AsyncSession,
|
||||
org_id: UUID,
|
||||
member_settings: OrgMemberLLMSettings,
|
||||
) -> None:
|
||||
"""Update LLM settings for all members of an organization.
|
||||
|
||||
Args:
|
||||
session: Database session (passed from caller for transaction)
|
||||
org_id: Organization ID
|
||||
member_settings: Typed LLM settings to apply to all members
|
||||
"""
|
||||
# Build update values from non-None fields
|
||||
values = member_settings.model_dump(exclude_none=True)
|
||||
|
||||
# Handle encrypted llm_api_key field - map to _llm_api_key column with encryption
|
||||
if 'llm_api_key' in values:
|
||||
raw_key = values.pop('llm_api_key')
|
||||
values['_llm_api_key'] = encrypt_value(raw_key)
|
||||
|
||||
if values:
|
||||
stmt = update(OrgMember).where(OrgMember.org_id == org_id).values(**values)
|
||||
await session.execute(stmt)
|
||||
|
||||
@@ -3,6 +3,7 @@ Service class for managing organization operations.
|
||||
Separates business logic from route handlers.
|
||||
"""
|
||||
|
||||
from typing import NoReturn
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
@@ -325,7 +326,7 @@ class OrgService:
|
||||
user_id: str,
|
||||
original_error: Exception,
|
||||
error_message: str,
|
||||
) -> None:
|
||||
) -> NoReturn:
|
||||
"""
|
||||
Handle failure by cleaning up LiteLLM resources and raising appropriate error.
|
||||
|
||||
|
||||
@@ -10,10 +10,10 @@ from server.constants import (
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.routes.org_models import OrphanedUserError
|
||||
from sqlalchemy import text
|
||||
from server.routes.org_models import OrgLLMSettingsUpdate, OrphanedUserError
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
@@ -386,3 +386,47 @@ class OrgStore:
|
||||
extra={'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def get_org_by_id_async(org_id: UUID) -> Org | None:
|
||||
"""Get organization by ID (async version)."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
return OrgStore._validate_org_version(org) if org else None
|
||||
|
||||
@staticmethod
|
||||
async def update_org_llm_settings_async(
|
||||
org_id: UUID,
|
||||
llm_settings: OrgLLMSettingsUpdate,
|
||||
) -> Org | None:
|
||||
"""Update organization LLM settings and propagate to members (async version).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
llm_settings: Typed LLM settings update model
|
||||
|
||||
Returns:
|
||||
Updated Org or None if not found
|
||||
"""
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
if not org:
|
||||
return None
|
||||
|
||||
# Apply updates to org
|
||||
llm_settings.apply_to_org(org)
|
||||
|
||||
# Propagate relevant settings to all org members
|
||||
member_updates = llm_settings.get_member_updates()
|
||||
if member_updates:
|
||||
await OrgMemberStore.update_all_members_llm_settings_async(
|
||||
session, org_id, member_updates
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
return org
|
||||
|
||||
@@ -10,7 +10,6 @@ from integrations.github.github_types import (
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import a_session_maker
|
||||
from storage.proactive_convos import ProactiveConversation
|
||||
|
||||
@@ -20,8 +19,6 @@ from openhands.integrations.service_types import ProviderType
|
||||
|
||||
@dataclass
|
||||
class ProactiveConversationStore:
|
||||
a_session_maker: sessionmaker = a_session_maker
|
||||
|
||||
def get_repo_id(self, provider: ProviderType, repo_id):
|
||||
return f'{provider.value}##{repo_id}'
|
||||
|
||||
@@ -51,7 +48,7 @@ class ProactiveConversationStore:
|
||||
|
||||
final_workflow_group = None
|
||||
|
||||
async with self.a_session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Start an explicit transaction with row-level locking
|
||||
async with session.begin():
|
||||
# Get the existing proactive conversation entry with FOR UPDATE lock
|
||||
@@ -142,7 +139,7 @@ class ProactiveConversationStore:
|
||||
# Calculate the cutoff time (current time - older_than_minutes)
|
||||
cutoff_time = datetime.now(UTC) - timedelta(minutes=older_than_minutes)
|
||||
|
||||
async with self.a_session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
async with session.begin():
|
||||
# Delete records older than the cutoff time
|
||||
delete_stmt = delete(ProactiveConversation).where(
|
||||
@@ -158,9 +155,9 @@ class ProactiveConversationStore:
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> ProactiveConversationStore:
|
||||
"""Get an instance of the GitlabWebhookStore.
|
||||
"""Get an instance of the ProactiveConversationStore.
|
||||
|
||||
Returns:
|
||||
An instance of GitlabWebhookStore
|
||||
An instance of ProactiveConversationStore
|
||||
"""
|
||||
return ProactiveConversationStore(a_session_maker)
|
||||
return ProactiveConversationStore()
|
||||
|
||||
@@ -2,8 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.stored_repository import StoredRepository
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
@@ -11,12 +11,11 @@ from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
|
||||
@dataclass
|
||||
class RepositoryStore:
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
|
||||
def store_projects(self, repositories: list[StoredRepository]) -> None:
|
||||
async def store_projects(self, repositories: list[StoredRepository]) -> None:
|
||||
"""
|
||||
Store repositories in database
|
||||
Store repositories in database (async version)
|
||||
|
||||
1. Make sure to store repositories if its ID doesn't exist
|
||||
2. If repository ID already exists, make sure to only update the repo is_public and repo_name fields
|
||||
@@ -26,17 +25,15 @@ class RepositoryStore:
|
||||
if not repositories:
|
||||
return
|
||||
|
||||
with self.session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Extract all repo_ids to check
|
||||
repo_ids = [r.repo_id for r in repositories]
|
||||
|
||||
# Get all existing repositories in a single query
|
||||
existing_repos = {
|
||||
r.repo_id: r
|
||||
for r in session.query(StoredRepository).filter(
|
||||
StoredRepository.repo_id.in_(repo_ids)
|
||||
)
|
||||
}
|
||||
result = await session.execute(
|
||||
select(StoredRepository).filter(StoredRepository.repo_id.in_(repo_ids))
|
||||
)
|
||||
existing_repos = {r.repo_id: r for r in result.scalars().all()}
|
||||
|
||||
# Process all repositories
|
||||
for repo in repositories:
|
||||
@@ -50,9 +47,9 @@ class RepositoryStore:
|
||||
session.add(repo)
|
||||
|
||||
# Commit all changes
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, config: OpenHandsConfig) -> RepositoryStore:
|
||||
"""Get an instance of the UserRepositoryStore."""
|
||||
return RepositoryStore(session_maker, config)
|
||||
return RepositoryStore(config)
|
||||
|
||||
@@ -234,6 +234,8 @@ class SaasConversationStore(ConversationStore):
|
||||
cls, config: OpenHandsConfig, user_id: str | None
|
||||
) -> ConversationStore:
|
||||
# user_id should not be None in SaaS, should we raise?
|
||||
# Use async version since callers now use asyncio.run_coroutine_threadsafe()
|
||||
# to dispatch to the main event loop where asyncpg connections work properly.
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
return SaasConversationStore(str(user_id), org_id, session_maker)
|
||||
|
||||
@@ -28,7 +28,7 @@ class SaasConversationValidator(ConversationValidator):
|
||||
|
||||
# Validate the API key and get the user_id
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
user_id = api_key_store.validate_api_key(api_key)
|
||||
user_id = await api_key_store.validate_api_key(api_key)
|
||||
|
||||
if not user_id:
|
||||
logger.warning('Invalid API key')
|
||||
|
||||
@@ -5,8 +5,8 @@ from base64 import b64decode, b64encode
|
||||
from dataclasses import dataclass
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import delete, select
|
||||
from storage.database import a_session_maker
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.user_store import UserStore
|
||||
|
||||
@@ -19,7 +19,6 @@ from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
@dataclass
|
||||
class SaasSecretsStore(SecretsStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
|
||||
async def load(self) -> Secrets | None:
|
||||
@@ -28,14 +27,15 @@ class SaasSecretsStore(SecretsStore):
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
|
||||
with self.session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Fetch all secrets for the given user ID
|
||||
query = session.query(StoredCustomSecrets).filter(
|
||||
query = select(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
)
|
||||
if org_id is not None:
|
||||
query = query.filter(StoredCustomSecrets.org_id == org_id)
|
||||
settings = query.all()
|
||||
result = await session.execute(query)
|
||||
settings = result.scalars().all()
|
||||
|
||||
if not settings:
|
||||
return Secrets()
|
||||
@@ -54,12 +54,15 @@ class SaasSecretsStore(SecretsStore):
|
||||
async def store(self, item: Secrets):
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
|
||||
async with a_session_maker() as session:
|
||||
# Incoming secrets are always the most updated ones
|
||||
# Delete all existing records and override with incoming ones
|
||||
session.query(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
).delete()
|
||||
await session.execute(
|
||||
delete(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
)
|
||||
)
|
||||
|
||||
# Prepare the new secrets data
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
@@ -89,7 +92,7 @@ class SaasSecretsStore(SecretsStore):
|
||||
)
|
||||
session.add(new_secret)
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
def _decrypt_kwargs(self, kwargs: dict):
|
||||
fernet = self._fernet()
|
||||
@@ -133,4 +136,4 @@ class SaasSecretsStore(SecretsStore):
|
||||
if not user_id:
|
||||
raise Exception('SaasSecretsStore cannot be constructed with no user_id')
|
||||
logger.debug(f'saas_secrets_store.get_instance::{user_id}')
|
||||
return SaasSecretsStore(user_id, session_maker, config)
|
||||
return SaasSecretsStore(user_id, config)
|
||||
|
||||
@@ -10,8 +10,9 @@ from cryptography.fernet import Fernet
|
||||
from pydantic import SecretStr
|
||||
from server.constants import LITE_LLM_API_URL
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import joinedload, sessionmaker
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager, get_openhands_cloud_key_alias
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
@@ -23,26 +24,24 @@ from storage.user_store import UserStore
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.llm import is_openhands_model
|
||||
|
||||
|
||||
@dataclass
|
||||
class SaasSettingsStore(SettingsStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
ENCRYPT_VALUES = ['llm_api_key', 'llm_api_key_for_byor', 'search_api_key']
|
||||
|
||||
def _get_user_settings_by_keycloak_id(
|
||||
async def _get_user_settings_by_keycloak_id_async(
|
||||
self, keycloak_user_id: str, session=None
|
||||
) -> UserSettings | None:
|
||||
"""
|
||||
Get UserSettings by keycloak_user_id.
|
||||
Get UserSettings by keycloak_user_id (async version).
|
||||
|
||||
Args:
|
||||
keycloak_user_id: The keycloak user ID to search for
|
||||
session: Optional existing database session. If not provided, creates a new one.
|
||||
session: Optional existing async database session. If not provided, creates a new one.
|
||||
|
||||
Returns:
|
||||
UserSettings object if found, None otherwise
|
||||
@@ -50,27 +49,26 @@ class SaasSettingsStore(SettingsStore):
|
||||
if not keycloak_user_id:
|
||||
return None
|
||||
|
||||
def _get_settings():
|
||||
if session:
|
||||
# Use provided session
|
||||
return (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
if session:
|
||||
# Use provided session
|
||||
result = await session.execute(
|
||||
select(UserSettings).filter(
|
||||
UserSettings.keycloak_user_id == keycloak_user_id
|
||||
)
|
||||
else:
|
||||
# Create new session
|
||||
with self.session_maker() as new_session:
|
||||
return (
|
||||
new_session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
)
|
||||
return result.scalars().first()
|
||||
else:
|
||||
# Create new session
|
||||
async with a_session_maker() as new_session:
|
||||
result = await new_session.execute(
|
||||
select(UserSettings).filter(
|
||||
UserSettings.keycloak_user_id == keycloak_user_id
|
||||
)
|
||||
|
||||
return _get_settings()
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def load(self) -> Settings | None:
|
||||
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
|
||||
user = await UserStore.get_user_by_id_async(self.user_id)
|
||||
if not user:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
return None
|
||||
@@ -83,7 +81,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
org = await OrgStore.get_org_by_id_async(org_id)
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
@@ -122,21 +120,22 @@ class SaasSettingsStore(SettingsStore):
|
||||
return settings
|
||||
|
||||
async def store(self, item: Settings):
|
||||
with self.session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
if not item:
|
||||
return None
|
||||
user = (
|
||||
session.query(User)
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(self.user_id))
|
||||
).first()
|
||||
)
|
||||
user = result.scalars().first()
|
||||
|
||||
if not user:
|
||||
# Check if we need to migrate from user_settings
|
||||
user_settings = None
|
||||
with session_maker() as session:
|
||||
user_settings = self._get_user_settings_by_keycloak_id(
|
||||
self.user_id, session
|
||||
async with a_session_maker() as new_session:
|
||||
user_settings = await self._get_user_settings_by_keycloak_id_async(
|
||||
self.user_id, new_session
|
||||
)
|
||||
if user_settings:
|
||||
user = await UserStore.migrate_user(self.user_id, user_settings)
|
||||
@@ -154,7 +153,8 @@ class SaasSettingsStore(SettingsStore):
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
|
||||
org: Org = session.query(Org).filter(Org.id == org_id).first()
|
||||
result = await session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
@@ -173,7 +173,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
if hasattr(model, key):
|
||||
setattr(model, key, value)
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(
|
||||
@@ -182,7 +182,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
user_id: str, # type: ignore[override]
|
||||
) -> SaasSettingsStore:
|
||||
logger.debug(f'saas_settings_store.get_instance::{user_id}')
|
||||
return SaasSettingsStore(user_id, session_maker, config)
|
||||
return SaasSettingsStore(user_id, config)
|
||||
|
||||
def _should_encrypt(self, key):
|
||||
return key in self.ENCRYPT_VALUES
|
||||
|
||||
@@ -2,38 +2,35 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.slack_conversation import SlackConversation
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlackConversationStore:
|
||||
session_maker: sessionmaker
|
||||
|
||||
async def get_slack_conversation(
|
||||
self, channel_id: str, parent_id: str
|
||||
) -> SlackConversation | None:
|
||||
"""Get a slack conversation by channel_id and message_ts.
|
||||
Both parameters are required to match for a conversation to be returned.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
conversation = (
|
||||
session.query(SlackConversation)
|
||||
.filter(SlackConversation.channel_id == channel_id)
|
||||
.filter(SlackConversation.parent_id == parent_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(SlackConversation).where(
|
||||
SlackConversation.channel_id == channel_id,
|
||||
SlackConversation.parent_id == parent_id,
|
||||
)
|
||||
)
|
||||
|
||||
return conversation
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create_slack_conversation(
|
||||
self, slack_converstion: SlackConversation
|
||||
) -> None:
|
||||
with self.session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.merge(slack_converstion)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> SlackConversationStore:
|
||||
return SlackConversationStore(session_maker)
|
||||
return SlackConversationStore()
|
||||
|
||||
@@ -32,6 +32,7 @@ class SlackTeamStore:
|
||||
# Store the token
|
||||
session.add(slack_team)
|
||||
session.commit()
|
||||
return slack_team
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
|
||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
@@ -12,12 +12,11 @@ from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
|
||||
@dataclass
|
||||
class UserRepositoryMapStore:
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
|
||||
def store_user_repo_mappings(self, mappings: list[UserRepositoryMap]) -> None:
|
||||
async def store_user_repo_mappings(self, mappings: list[UserRepositoryMap]) -> None:
|
||||
"""
|
||||
Store user-repository mappings in database
|
||||
Store user-repository mappings in database (async version)
|
||||
|
||||
1. Make sure to store mappings if they don't exist
|
||||
2. If a mapping already exists (same user_id and repo_id), update the admin field
|
||||
@@ -30,18 +29,20 @@ class UserRepositoryMapStore:
|
||||
if not mappings:
|
||||
return
|
||||
|
||||
with self.session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Extract all user_id/repo_id pairs to check
|
||||
mapping_keys = [(m.user_id, m.repo_id) for m in mappings]
|
||||
|
||||
# Get all existing mappings in a single query
|
||||
existing_mappings = {
|
||||
(m.user_id, m.repo_id): m
|
||||
for m in session.query(UserRepositoryMap).filter(
|
||||
result = await session.execute(
|
||||
select(UserRepositoryMap).filter(
|
||||
sqlalchemy.tuple_(
|
||||
UserRepositoryMap.user_id, UserRepositoryMap.repo_id
|
||||
).in_(mapping_keys)
|
||||
)
|
||||
)
|
||||
existing_mappings = {
|
||||
(m.user_id, m.repo_id): m for m in result.scalars().all()
|
||||
}
|
||||
|
||||
# Process all mappings
|
||||
@@ -56,9 +57,9 @@ class UserRepositoryMapStore:
|
||||
session.add(mapping)
|
||||
|
||||
# Commit all changes
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, config: OpenHandsConfig) -> UserRepositoryMapStore:
|
||||
"""Get an instance of the UserRepositoryMapStore."""
|
||||
return UserRepositoryMapStore(session_maker, config)
|
||||
return UserRepositoryMapStore(config)
|
||||
|
||||
@@ -227,7 +227,7 @@ class UserStore:
|
||||
'user_store:migrate_user:calling_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await migrate_customer(session, user_id, org)
|
||||
await migrate_customer(user_id, org)
|
||||
logger.debug(
|
||||
'user_store:migrate_user:done_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
"""SQLAlchemy model for verified LLM models."""
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Identity,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
text,
|
||||
)
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class VerifiedModel(Base): # type: ignore
|
||||
"""A verified LLM model available in the model selector.
|
||||
|
||||
The composite unique constraint on (model_name, provider) allows the same
|
||||
model name to exist under different providers (e.g. 'claude-sonnet' under
|
||||
both 'openhands' and 'anthropic').
|
||||
"""
|
||||
|
||||
__tablename__ = 'verified_models'
|
||||
__table_args__ = (
|
||||
UniqueConstraint('model_name', 'provider', name='uq_verified_model_provider'),
|
||||
)
|
||||
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
model_name = Column(String(255), nullable=False)
|
||||
provider = Column(String(100), nullable=False, index=True)
|
||||
is_enabled = Column(
|
||||
Boolean, nullable=False, default=True, server_default=text('true')
|
||||
)
|
||||
created_at = Column(DateTime, nullable=False, server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime, nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
@@ -1,187 +0,0 @@
|
||||
"""Store for managing verified LLM models in the database."""
|
||||
|
||||
from sqlalchemy import and_
|
||||
from storage.database import session_maker
|
||||
from storage.verified_model import VerifiedModel
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class VerifiedModelStore:
|
||||
"""Store for CRUD operations on verified models.
|
||||
|
||||
Follows the project convention of static methods with session_maker()
|
||||
(see UserStore, OrgMemberStore for reference).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_enabled_models() -> list[VerifiedModel]:
|
||||
"""Get all enabled models.
|
||||
|
||||
Returns:
|
||||
list[VerifiedModel]: All models where is_enabled is True
|
||||
"""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(VerifiedModel)
|
||||
.filter(VerifiedModel.is_enabled.is_(True))
|
||||
.order_by(VerifiedModel.provider, VerifiedModel.model_name)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_models_by_provider(provider: str) -> list[VerifiedModel]:
|
||||
"""Get all enabled models for a specific provider.
|
||||
|
||||
Args:
|
||||
provider: The provider name (e.g., 'openhands', 'anthropic')
|
||||
"""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(VerifiedModel)
|
||||
.filter(
|
||||
and_(
|
||||
VerifiedModel.provider == provider,
|
||||
VerifiedModel.is_enabled.is_(True),
|
||||
)
|
||||
)
|
||||
.order_by(VerifiedModel.model_name)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_all_models() -> list[VerifiedModel]:
|
||||
"""Get all models (including disabled)."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(VerifiedModel)
|
||||
.order_by(VerifiedModel.provider, VerifiedModel.model_name)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_model(model_name: str, provider: str) -> VerifiedModel | None:
|
||||
"""Get a model by its composite key (model_name, provider).
|
||||
|
||||
Args:
|
||||
model_name: The model identifier
|
||||
provider: The provider name
|
||||
"""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(VerifiedModel)
|
||||
.filter(
|
||||
and_(
|
||||
VerifiedModel.model_name == model_name,
|
||||
VerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_model(
|
||||
model_name: str, provider: str, is_enabled: bool = True
|
||||
) -> VerifiedModel:
|
||||
"""Create a new verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model identifier
|
||||
provider: The provider name
|
||||
is_enabled: Whether the model is enabled (default True)
|
||||
|
||||
Raises:
|
||||
ValueError: If a model with the same (model_name, provider) already exists
|
||||
"""
|
||||
with session_maker() as session:
|
||||
existing = (
|
||||
session.query(VerifiedModel)
|
||||
.filter(
|
||||
and_(
|
||||
VerifiedModel.model_name == model_name,
|
||||
VerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f'Model {provider}/{model_name} already exists')
|
||||
|
||||
model = VerifiedModel(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
is_enabled=is_enabled,
|
||||
)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
session.refresh(model)
|
||||
logger.info(f'Created verified model: {provider}/{model_name}')
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def update_model(
|
||||
model_name: str,
|
||||
provider: str,
|
||||
is_enabled: bool | None = None,
|
||||
) -> VerifiedModel | None:
|
||||
"""Update an existing verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model name to update
|
||||
provider: The provider name
|
||||
is_enabled: New enabled state (optional)
|
||||
|
||||
Returns:
|
||||
The updated model if found, None otherwise
|
||||
"""
|
||||
with session_maker() as session:
|
||||
model = (
|
||||
session.query(VerifiedModel)
|
||||
.filter(
|
||||
and_(
|
||||
VerifiedModel.model_name == model_name,
|
||||
VerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not model:
|
||||
return None
|
||||
|
||||
if is_enabled is not None:
|
||||
model.is_enabled = is_enabled
|
||||
|
||||
session.commit()
|
||||
session.refresh(model)
|
||||
logger.info(f'Updated verified model: {provider}/{model_name}')
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def delete_model(model_name: str, provider: str) -> bool:
|
||||
"""Delete a verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model name to delete
|
||||
provider: The provider name
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
with session_maker() as session:
|
||||
model = (
|
||||
session.query(VerifiedModel)
|
||||
.filter(
|
||||
and_(
|
||||
VerifiedModel.model_name == model_name,
|
||||
VerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not model:
|
||||
return False
|
||||
|
||||
session.delete(model)
|
||||
session.commit()
|
||||
logger.info(f'Deleted verified model: {provider}/{model_name}')
|
||||
return True
|
||||
@@ -4,11 +4,20 @@ from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from server.constants import ORG_SETTINGS_VERSION
|
||||
from server.verified_models.verified_model_service import (
|
||||
StoredVerifiedModel, # noqa: F401
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
|
||||
# Anything not loaded here may not have a table created for it.
|
||||
from storage.api_key import ApiKey # noqa: F401
|
||||
from storage.base import Base
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.device_code import DeviceCode # noqa: F401
|
||||
@@ -25,12 +34,20 @@ from storage.stored_conversation_metadata_saas import (
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.user import User
|
||||
from storage.verified_model import VerifiedModel # noqa: F401
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def db_path(tmp_path):
|
||||
"""Create a unique temp file path for each test."""
|
||||
return str(tmp_path / 'test.db')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
def engine(db_path):
|
||||
"""Create a sync engine with tables using file-based DB."""
|
||||
engine = create_engine(
|
||||
f'sqlite:///{db_path}', connect_args={'check_same_thread': False}
|
||||
)
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
@@ -40,6 +57,36 @@ def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_engine(db_path):
|
||||
"""Create an async engine using the SAME file-based database."""
|
||||
async_engine = create_async_engine(
|
||||
f'sqlite+aiosqlite:///{db_path}',
|
||||
connect_args={'check_same_thread': False},
|
||||
)
|
||||
|
||||
async def create_tables():
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Run the async function synchronously
|
||||
import asyncio
|
||||
|
||||
asyncio.run(create_tables())
|
||||
return async_engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker bound to the async engine."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
bind=async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
return async_session_maker
|
||||
|
||||
|
||||
def add_minimal_fixtures(session_maker):
|
||||
with session_maker() as session:
|
||||
session.add(
|
||||
|
||||
@@ -7,7 +7,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from integrations.jira.jira_manager import JiraManager
|
||||
from integrations.jira.jira_payload import JiraEventType, JiraWebhookPayload
|
||||
from integrations.models import Message, SourceType
|
||||
|
||||
from openhands.server.types import (
|
||||
LLMAuthenticationError,
|
||||
@@ -274,9 +273,8 @@ class TestSendMessage:
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
message = Message(source=SourceType.JIRA, message='Test message')
|
||||
result = await jira_manager.send_message(
|
||||
message,
|
||||
'Test message',
|
||||
'PROJ-123',
|
||||
'cloud-123',
|
||||
'service@test.com',
|
||||
|
||||
268
enterprise/tests/unit/integrations/jira/test_jira_payload.py
Normal file
268
enterprise/tests/unit/integrations/jira/test_jira_payload.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
Tests for JiraPayloadParser.
|
||||
|
||||
These tests verify the parsing behavior of Jira webhook payloads,
|
||||
including the handling of optional fields like user_email which
|
||||
may not be present in webhook payloads from Jira.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from integrations.jira.jira_payload import (
|
||||
JiraEventType,
|
||||
JiraPayloadError,
|
||||
JiraPayloadParser,
|
||||
JiraPayloadSkipped,
|
||||
JiraPayloadSuccess,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parser():
|
||||
"""Create a JiraPayloadParser with standard OpenHands labels."""
|
||||
return JiraPayloadParser(oh_label='openhands', inline_oh_label='@openhands')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_label_payload():
|
||||
"""Create a valid jira:issue_updated payload with OpenHands label."""
|
||||
return {
|
||||
'webhookEvent': 'jira:issue_updated',
|
||||
'issue': {
|
||||
'id': '12345',
|
||||
'key': 'TEST-123',
|
||||
'self': 'https://test.atlassian.net/rest/api/2/issue/12345',
|
||||
},
|
||||
'user': {
|
||||
'displayName': 'Test User',
|
||||
'accountId': 'account-123',
|
||||
'emailAddress': 'test@example.com',
|
||||
},
|
||||
'changelog': {
|
||||
'items': [
|
||||
{
|
||||
'field': 'labels',
|
||||
'toString': 'openhands',
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_comment_payload():
|
||||
"""Create a valid comment_created payload with OpenHands mention."""
|
||||
return {
|
||||
'webhookEvent': 'comment_created',
|
||||
'issue': {
|
||||
'id': '12345',
|
||||
'key': 'TEST-123',
|
||||
'self': 'https://test.atlassian.net/rest/api/2/issue/12345',
|
||||
},
|
||||
'comment': {
|
||||
'body': '@openhands please fix this bug',
|
||||
'author': {
|
||||
'displayName': 'Test User',
|
||||
'accountId': 'account-123',
|
||||
'emailAddress': 'test@example.com',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestUserEmailOptional:
|
||||
"""Tests verifying user_email is optional in webhook payloads.
|
||||
|
||||
Jira webhooks may not include emailAddress in the user data.
|
||||
The parser should accept payloads without this field.
|
||||
"""
|
||||
|
||||
def test_label_event_succeeds_without_email_address(
|
||||
self, parser, valid_label_payload
|
||||
):
|
||||
"""Verify label event parsing succeeds when emailAddress is missing."""
|
||||
# Arrange - remove emailAddress from user data
|
||||
del valid_label_payload['user']['emailAddress']
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.user_email == ''
|
||||
assert result.payload.display_name == 'Test User'
|
||||
assert result.payload.account_id == 'account-123'
|
||||
|
||||
def test_comment_event_succeeds_without_email_address(
|
||||
self, parser, valid_comment_payload
|
||||
):
|
||||
"""Verify comment event parsing succeeds when emailAddress is missing."""
|
||||
# Arrange - remove emailAddress from author data
|
||||
del valid_comment_payload['comment']['author']['emailAddress']
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_comment_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.user_email == ''
|
||||
assert result.payload.display_name == 'Test User'
|
||||
assert result.payload.account_id == 'account-123'
|
||||
|
||||
def test_user_email_preserved_when_present(self, parser, valid_label_payload):
|
||||
"""Verify user_email is captured when emailAddress is present."""
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.user_email == 'test@example.com'
|
||||
|
||||
|
||||
class TestRequiredFieldValidation:
|
||||
"""Tests verifying required fields are still validated."""
|
||||
|
||||
def test_missing_issue_id_returns_error(self, parser, valid_label_payload):
|
||||
"""Verify parsing fails when issue.id is missing."""
|
||||
# Arrange
|
||||
del valid_label_payload['issue']['id']
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadError)
|
||||
assert 'issue.id' in result.error
|
||||
|
||||
def test_missing_issue_key_returns_error(self, parser, valid_label_payload):
|
||||
"""Verify parsing fails when issue.key is missing."""
|
||||
# Arrange
|
||||
del valid_label_payload['issue']['key']
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadError)
|
||||
assert 'issue.key' in result.error
|
||||
|
||||
def test_missing_display_name_returns_error(self, parser, valid_label_payload):
|
||||
"""Verify parsing fails when user.displayName is missing."""
|
||||
# Arrange
|
||||
del valid_label_payload['user']['displayName']
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadError)
|
||||
assert 'displayName' in result.error
|
||||
|
||||
def test_missing_account_id_returns_error(self, parser, valid_label_payload):
|
||||
"""Verify parsing fails when user.accountId is missing."""
|
||||
# Arrange
|
||||
del valid_label_payload['user']['accountId']
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadError)
|
||||
assert 'accountId' in result.error
|
||||
|
||||
def test_missing_issue_self_url_returns_error(self, parser, valid_label_payload):
|
||||
"""Verify parsing fails when issue.self URL is missing."""
|
||||
# Arrange
|
||||
del valid_label_payload['issue']['self']
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadError)
|
||||
assert 'workspace_name' in result.error or 'base_api_url' in result.error
|
||||
|
||||
|
||||
class TestEventTypeDetection:
|
||||
"""Tests for webhook event type detection."""
|
||||
|
||||
def test_issue_updated_with_label_returns_labeled_ticket(
|
||||
self, parser, valid_label_payload
|
||||
):
|
||||
"""Verify jira:issue_updated with label is detected as LABELED_TICKET."""
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.event_type == JiraEventType.LABELED_TICKET
|
||||
|
||||
def test_comment_created_with_mention_returns_comment_mention(
|
||||
self, parser, valid_comment_payload
|
||||
):
|
||||
"""Verify comment_created with mention is detected as COMMENT_MENTION."""
|
||||
# Act
|
||||
result = parser.parse(valid_comment_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.event_type == JiraEventType.COMMENT_MENTION
|
||||
|
||||
def test_unhandled_event_type_returns_skipped(self, parser):
|
||||
"""Verify unknown event types are skipped."""
|
||||
# Arrange
|
||||
payload = {'webhookEvent': 'jira:issue_deleted'}
|
||||
|
||||
# Act
|
||||
result = parser.parse(payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
assert 'Unhandled' in result.skip_reason
|
||||
|
||||
|
||||
class TestLabelFiltering:
|
||||
"""Tests for OpenHands label filtering."""
|
||||
|
||||
def test_label_event_without_openhands_label_skipped(
|
||||
self, parser, valid_label_payload
|
||||
):
|
||||
"""Verify label events without OpenHands label are skipped."""
|
||||
# Arrange - change label to something else
|
||||
valid_label_payload['changelog']['items'][0]['toString'] = 'other-label'
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
assert 'openhands' in result.skip_reason
|
||||
|
||||
|
||||
class TestCommentFiltering:
|
||||
"""Tests for OpenHands comment mention filtering."""
|
||||
|
||||
def test_comment_without_mention_skipped(self, parser, valid_comment_payload):
|
||||
"""Verify comments without OpenHands mention are skipped."""
|
||||
# Arrange - remove mention from comment body
|
||||
valid_comment_payload['comment']['body'] = 'Please fix this bug'
|
||||
|
||||
# Act
|
||||
result = parser.parse(valid_comment_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
assert '@openhands' in result.skip_reason
|
||||
|
||||
|
||||
class TestWorkspaceExtraction:
|
||||
"""Tests for workspace name extraction from issue URL."""
|
||||
|
||||
def test_workspace_name_extracted_from_self_url(self, parser, valid_label_payload):
|
||||
"""Verify workspace name is extracted from issue self URL."""
|
||||
# Act
|
||||
result = parser.parse(valid_label_payload)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JiraPayloadSuccess)
|
||||
assert result.payload.workspace_name == 'test.atlassian.net'
|
||||
assert result.payload.base_api_url == 'https://test.atlassian.net'
|
||||
@@ -738,7 +738,7 @@ class TestStartJob:
|
||||
# Should send error message about re-login
|
||||
jira_dc_manager.send_message.assert_called_once()
|
||||
call_args = jira_dc_manager.send_message.call_args[0]
|
||||
assert 'Please re-login' in call_args[0].message
|
||||
assert 'Please re-login' in call_args[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_llm_authentication_error(
|
||||
@@ -763,7 +763,7 @@ class TestStartJob:
|
||||
# Should send error message about LLM API key
|
||||
jira_dc_manager.send_message.assert_called_once()
|
||||
call_args = jira_dc_manager.send_message.call_args[0]
|
||||
assert 'valid LLM API key' in call_args[0].message
|
||||
assert 'valid LLM API key' in call_args[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_session_expired_error(
|
||||
@@ -788,8 +788,8 @@ class TestStartJob:
|
||||
# Should send error message about session expired
|
||||
jira_dc_manager.send_message.assert_called_once()
|
||||
call_args = jira_dc_manager.send_message.call_args[0]
|
||||
assert 'session has expired' in call_args[0].message
|
||||
assert 'login again' in call_args[0].message
|
||||
assert 'session has expired' in call_args[0]
|
||||
assert 'login again' in call_args[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_unexpected_error(
|
||||
@@ -814,7 +814,7 @@ class TestStartJob:
|
||||
# Should send generic error message
|
||||
jira_dc_manager.send_message.assert_called_once()
|
||||
call_args = jira_dc_manager.send_message.call_args[0]
|
||||
assert 'unexpected error' in call_args[0].message
|
||||
assert 'unexpected error' in call_args[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_send_message_fails(
|
||||
@@ -943,9 +943,8 @@ class TestSendMessage:
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
message = Message(source=SourceType.JIRA_DC, message='Test message')
|
||||
result = await jira_dc_manager.send_message(
|
||||
message, 'PROJ-123', 'https://jira.company.com', 'bearer_token'
|
||||
'Test message', 'PROJ-123', 'https://jira.company.com', 'bearer_token'
|
||||
)
|
||||
|
||||
assert result == {'id': 'comment_id'}
|
||||
@@ -1014,7 +1013,7 @@ class TestSendRepoSelectionComment:
|
||||
|
||||
jira_dc_manager.send_message.assert_called_once()
|
||||
call_args = jira_dc_manager.send_message.call_args[0]
|
||||
assert 'which repository to work with' in call_args[0].message
|
||||
assert 'which repository to work with' in call_args[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_repo_selection_comment_send_fails(
|
||||
|
||||
@@ -18,9 +18,11 @@ from openhands.core.schema.agent import AgentState
|
||||
class TestJiraDcNewConversationView:
|
||||
"""Tests for JiraDcNewConversationView"""
|
||||
|
||||
def test_get_instructions(self, new_conversation_view, mock_jinja_env):
|
||||
async def test_get_instructions(self, new_conversation_view, mock_jinja_env):
|
||||
"""Test _get_instructions method"""
|
||||
instructions, user_msg = new_conversation_view._get_instructions(mock_jinja_env)
|
||||
instructions, user_msg = await new_conversation_view._get_instructions(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
assert instructions == 'Test Jira DC instructions template'
|
||||
assert 'PROJ-123' in user_msg
|
||||
@@ -83,9 +85,9 @@ class TestJiraDcNewConversationView:
|
||||
class TestJiraDcExistingConversationView:
|
||||
"""Tests for JiraDcExistingConversationView"""
|
||||
|
||||
def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
|
||||
async def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
|
||||
"""Test _get_instructions method"""
|
||||
instructions, user_msg = existing_conversation_view._get_instructions(
|
||||
instructions, user_msg = await existing_conversation_view._get_instructions(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
|
||||
@@ -802,7 +802,7 @@ class TestStartJob:
|
||||
# Should send error message about re-login
|
||||
linear_manager.send_message.assert_called_once()
|
||||
call_args = linear_manager.send_message.call_args[0]
|
||||
assert 'Please re-login' in call_args[0].message
|
||||
assert 'Please re-login' in call_args[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_llm_authentication_error(
|
||||
@@ -828,7 +828,7 @@ class TestStartJob:
|
||||
# Should send error message about LLM API key
|
||||
linear_manager.send_message.assert_called_once()
|
||||
call_args = linear_manager.send_message.call_args[0]
|
||||
assert 'valid LLM API key' in call_args[0].message
|
||||
assert 'valid LLM API key' in call_args[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_session_expired_error(
|
||||
@@ -854,8 +854,8 @@ class TestStartJob:
|
||||
# Should send error message about session expired
|
||||
linear_manager.send_message.assert_called_once()
|
||||
call_args = linear_manager.send_message.call_args[0]
|
||||
assert 'session has expired' in call_args[0].message
|
||||
assert 'login again' in call_args[0].message
|
||||
assert 'session has expired' in call_args[0]
|
||||
assert 'login again' in call_args[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_unexpected_error(
|
||||
@@ -881,7 +881,7 @@ class TestStartJob:
|
||||
# Should send generic error message
|
||||
linear_manager.send_message.assert_called_once()
|
||||
call_args = linear_manager.send_message.call_args[0]
|
||||
assert 'unexpected error' in call_args[0].message
|
||||
assert 'unexpected error' in call_args[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_send_message_fails(
|
||||
@@ -1049,8 +1049,9 @@ class TestSendMessage:
|
||||
|
||||
linear_manager._query_api = AsyncMock(return_value=mock_response)
|
||||
|
||||
message = Message(source=SourceType.LINEAR, message='Test message')
|
||||
result = await linear_manager.send_message(message, 'issue_id', 'api_key')
|
||||
result = await linear_manager.send_message(
|
||||
'Test message', 'issue_id', 'api_key'
|
||||
)
|
||||
|
||||
assert result == mock_response
|
||||
linear_manager._query_api.assert_called_once()
|
||||
@@ -1114,7 +1115,7 @@ class TestSendRepoSelectionComment:
|
||||
|
||||
linear_manager.send_message.assert_called_once()
|
||||
call_args = linear_manager.send_message.call_args[0]
|
||||
assert 'which repository to work with' in call_args[0].message
|
||||
assert 'which repository to work with' in call_args[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_repo_selection_comment_send_fails(
|
||||
|
||||
@@ -18,9 +18,11 @@ from openhands.core.schema.agent import AgentState
|
||||
class TestLinearNewConversationView:
|
||||
"""Tests for LinearNewConversationView"""
|
||||
|
||||
def test_get_instructions(self, new_conversation_view, mock_jinja_env):
|
||||
async def test_get_instructions(self, new_conversation_view, mock_jinja_env):
|
||||
"""Test _get_instructions method"""
|
||||
instructions, user_msg = new_conversation_view._get_instructions(mock_jinja_env)
|
||||
instructions, user_msg = await new_conversation_view._get_instructions(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
assert instructions == 'Test instructions template'
|
||||
assert 'TEST-123' in user_msg
|
||||
@@ -83,9 +85,9 @@ class TestLinearNewConversationView:
|
||||
class TestLinearExistingConversationView:
|
||||
"""Tests for LinearExistingConversationView"""
|
||||
|
||||
def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
|
||||
async def test_get_instructions(self, existing_conversation_view, mock_jinja_env):
|
||||
"""Test _get_instructions method"""
|
||||
instructions, user_msg = existing_conversation_view._get_instructions(
|
||||
instructions, user_msg = await existing_conversation_view._get_instructions(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
|
||||
@@ -263,7 +263,9 @@ class TestPausedSandboxResumption:
|
||||
@patch('openhands.app_server.config.get_httpx_client')
|
||||
@patch('openhands.app_server.event_callback.util.ensure_running_sandbox')
|
||||
@patch('openhands.app_server.event_callback.util.get_agent_server_url_from_sandbox')
|
||||
@patch.object(SlackUpdateExistingConversationView, '_get_instructions')
|
||||
@patch.object(
|
||||
SlackUpdateExistingConversationView, '_get_instructions', new_callable=AsyncMock
|
||||
)
|
||||
async def test_paused_sandbox_resumption(
|
||||
self,
|
||||
mock_get_instructions,
|
||||
|
||||
@@ -34,7 +34,6 @@ async def test_send_comment_to_jira_success(mock_jira_manager, processor):
|
||||
)
|
||||
mock_jira_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_jira_manager.send_message = AsyncMock()
|
||||
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
# Action
|
||||
await processor._send_comment_to_jira('This is a summary.')
|
||||
@@ -130,7 +129,6 @@ async def test_call_sends_summary_to_jira(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_jira_manager.send_message = AsyncMock()
|
||||
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
with patch(
|
||||
'server.conversation_callback_processor.jira_callback_processor.asyncio.create_task'
|
||||
@@ -200,7 +198,6 @@ async def test_send_comment_to_jira_api_error(mock_jira_manager, processor):
|
||||
)
|
||||
mock_jira_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_jira_manager.send_message = AsyncMock(side_effect=Exception('API Error'))
|
||||
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
# Action - should not raise exception, but handle it gracefully
|
||||
await processor._send_comment_to_jira('This is a summary.')
|
||||
@@ -328,18 +325,15 @@ async def test_send_comment_to_jira_message_construction(mock_jira_manager, proc
|
||||
)
|
||||
mock_jira_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_jira_manager.send_message = AsyncMock()
|
||||
mock_outgoing_message = MagicMock()
|
||||
mock_jira_manager.create_outgoing_message.return_value = mock_outgoing_message
|
||||
|
||||
test_message = 'This is a test summary message.'
|
||||
|
||||
# Action
|
||||
await processor._send_comment_to_jira(test_message)
|
||||
|
||||
# Assert
|
||||
mock_jira_manager.create_outgoing_message.assert_called_once_with(msg=test_message)
|
||||
# Assert - send_message now receives the string directly
|
||||
mock_jira_manager.send_message.assert_called_once_with(
|
||||
mock_outgoing_message,
|
||||
test_message,
|
||||
issue_key='TEST-123',
|
||||
jira_cloud_id='cloud123',
|
||||
svc_acc_email='service@test.com',
|
||||
@@ -386,7 +380,6 @@ async def test_call_creates_background_task_for_sending(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_jira_manager.send_message = AsyncMock()
|
||||
mock_jira_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
with patch(
|
||||
'server.conversation_callback_processor.jira_callback_processor.asyncio.create_task'
|
||||
|
||||
@@ -32,7 +32,6 @@ async def test_send_comment_to_jira_dc_success(mock_jira_dc_manager, processor):
|
||||
)
|
||||
mock_jira_dc_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_jira_dc_manager.send_message = AsyncMock()
|
||||
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
# Action
|
||||
await processor._send_comment_to_jira_dc('This is a summary.')
|
||||
@@ -125,7 +124,6 @@ async def test_call_sends_summary_to_jira_dc(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_jira_dc_manager.send_message = AsyncMock()
|
||||
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
with patch(
|
||||
'server.conversation_callback_processor.jira_dc_callback_processor.asyncio.create_task'
|
||||
@@ -200,7 +198,6 @@ async def test_send_comment_to_jira_dc_api_error(mock_jira_dc_manager, processor
|
||||
)
|
||||
mock_jira_dc_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_jira_dc_manager.send_message = AsyncMock(side_effect=Exception('API Error'))
|
||||
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
# Action - should not raise exception, but handle it gracefully
|
||||
await processor._send_comment_to_jira_dc('This is a summary.')
|
||||
@@ -328,20 +325,15 @@ async def test_send_comment_to_jira_dc_message_construction(
|
||||
)
|
||||
mock_jira_dc_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_jira_dc_manager.send_message = AsyncMock()
|
||||
mock_outgoing_message = MagicMock()
|
||||
mock_jira_dc_manager.create_outgoing_message.return_value = mock_outgoing_message
|
||||
|
||||
test_message = 'This is a test summary message.'
|
||||
|
||||
# Action
|
||||
await processor._send_comment_to_jira_dc(test_message)
|
||||
|
||||
# Assert
|
||||
mock_jira_dc_manager.create_outgoing_message.assert_called_once_with(
|
||||
msg=test_message
|
||||
)
|
||||
# Assert - send_message now receives the string directly
|
||||
mock_jira_dc_manager.send_message.assert_called_once_with(
|
||||
mock_outgoing_message,
|
||||
test_message,
|
||||
issue_key='TEST-123',
|
||||
base_api_url='https://test-jira-dc.company.com',
|
||||
svc_acc_api_key='decrypted_key',
|
||||
@@ -384,7 +376,6 @@ async def test_call_creates_background_task_for_sending(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_jira_dc_manager.send_message = AsyncMock()
|
||||
mock_jira_dc_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
with patch(
|
||||
'server.conversation_callback_processor.jira_dc_callback_processor.asyncio.create_task'
|
||||
|
||||
@@ -32,7 +32,6 @@ async def test_send_comment_to_linear_success(mock_linear_manager, processor):
|
||||
)
|
||||
mock_linear_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_linear_manager.send_message = AsyncMock()
|
||||
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
# Action
|
||||
await processor._send_comment_to_linear('This is a summary.')
|
||||
@@ -125,7 +124,6 @@ async def test_call_sends_summary_to_linear(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_linear_manager.send_message = AsyncMock()
|
||||
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
with patch(
|
||||
'server.conversation_callback_processor.linear_callback_processor.asyncio.create_task'
|
||||
@@ -200,7 +198,6 @@ async def test_send_comment_to_linear_api_error(mock_linear_manager, processor):
|
||||
)
|
||||
mock_linear_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_linear_manager.send_message = AsyncMock(side_effect=Exception('API Error'))
|
||||
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
# Action - should not raise exception, but handle it gracefully
|
||||
await processor._send_comment_to_linear('This is a summary.')
|
||||
@@ -328,20 +325,15 @@ async def test_send_comment_to_linear_message_construction(
|
||||
)
|
||||
mock_linear_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
|
||||
mock_linear_manager.send_message = AsyncMock()
|
||||
mock_outgoing_message = MagicMock()
|
||||
mock_linear_manager.create_outgoing_message.return_value = mock_outgoing_message
|
||||
|
||||
test_message = 'This is a test summary message.'
|
||||
|
||||
# Action
|
||||
await processor._send_comment_to_linear(test_message)
|
||||
|
||||
# Assert
|
||||
mock_linear_manager.create_outgoing_message.assert_called_once_with(
|
||||
msg=test_message
|
||||
)
|
||||
# Assert - send_message now receives the string directly
|
||||
mock_linear_manager.send_message.assert_called_once_with(
|
||||
mock_outgoing_message,
|
||||
test_message,
|
||||
'TEST-123', # issue_id
|
||||
'decrypted_key', # api_key
|
||||
)
|
||||
@@ -383,7 +375,6 @@ async def test_call_creates_background_task_for_sending(
|
||||
return_value=mock_workspace
|
||||
)
|
||||
mock_linear_manager.send_message = AsyncMock()
|
||||
mock_linear_manager.create_outgoing_message.return_value = MagicMock()
|
||||
|
||||
with patch(
|
||||
'server.conversation_callback_processor.linear_callback_processor.asyncio.create_task'
|
||||
|
||||
@@ -16,8 +16,15 @@ from storage.device_code import DeviceCode
|
||||
|
||||
@pytest.fixture
|
||||
def mock_device_code_store():
|
||||
"""Mock device code store."""
|
||||
return MagicMock()
|
||||
"""Mock device code store with async methods."""
|
||||
mock = MagicMock()
|
||||
mock.create_device_code = AsyncMock()
|
||||
mock.get_by_device_code = AsyncMock()
|
||||
mock.get_by_user_code = AsyncMock()
|
||||
mock.authorize_device_code = AsyncMock()
|
||||
mock.deny_device_code = AsyncMock()
|
||||
mock.update_poll_time = AsyncMock()
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -54,7 +61,7 @@ class TestDeviceAuthorization:
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
current_interval=5, # Default interval
|
||||
)
|
||||
mock_store.create_device_code.return_value = mock_device
|
||||
mock_store.create_device_code = AsyncMock(return_value=mock_device)
|
||||
|
||||
result = await device_authorization(mock_request)
|
||||
|
||||
@@ -76,7 +83,7 @@ class TestDeviceAuthorization:
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=10),
|
||||
current_interval=15, # Increased interval from previous rate limiting
|
||||
)
|
||||
mock_store.create_device_code.return_value = mock_device
|
||||
mock_store.create_device_code = AsyncMock(return_value=mock_device)
|
||||
|
||||
result = await device_authorization(mock_request)
|
||||
|
||||
@@ -113,10 +120,10 @@ class TestDeviceToken:
|
||||
mock_device.status = status
|
||||
# Mock rate limiting - return False (not too fast) and default interval
|
||||
mock_device.check_rate_limit.return_value = (False, 5)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
else:
|
||||
mock_store.get_by_device_code.return_value = None
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=None)
|
||||
|
||||
result = await device_token(device_code=device_code)
|
||||
|
||||
@@ -142,12 +149,14 @@ class TestDeviceToken:
|
||||
)
|
||||
# Mock rate limiting - return False (not too fast) and default interval
|
||||
mock_device.check_rate_limit.return_value = (False, 5)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
|
||||
# Mock API key retrieval
|
||||
# Mock API key retrieval - use AsyncMock for async method
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.retrieve_api_key_by_name.return_value = 'test-api-key'
|
||||
mock_api_key_store.retrieve_api_key_by_name = AsyncMock(
|
||||
return_value='test-api-key'
|
||||
)
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
result = await device_token(device_code=device_code)
|
||||
@@ -176,7 +185,7 @@ class TestDeviceVerificationAuthenticated:
|
||||
self, mock_store, mock_api_key_class
|
||||
):
|
||||
"""Test verification with invalid device code."""
|
||||
mock_store.get_by_user_code.return_value = None
|
||||
mock_store.get_by_user_code = AsyncMock(return_value=None)
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
await device_verification_authenticated(
|
||||
@@ -189,7 +198,7 @@ class TestDeviceVerificationAuthenticated:
|
||||
"""Test verification with already processed device code."""
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = False
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
await device_verification_authenticated(
|
||||
@@ -203,8 +212,8 @@ class TestDeviceVerificationAuthenticated:
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True
|
||||
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.authorize_device_code = AsyncMock(return_value=True)
|
||||
|
||||
# Mock API key store with async create_api_key
|
||||
mock_api_key_store = MagicMock()
|
||||
@@ -248,15 +257,17 @@ class TestDeviceVerificationAuthenticated:
|
||||
mock_device2.is_pending.return_value = True
|
||||
|
||||
# Configure mock store to return appropriate device for each user_code
|
||||
def get_by_user_code_side_effect(user_code):
|
||||
async def get_by_user_code_side_effect(user_code):
|
||||
if user_code == device1_code:
|
||||
return mock_device1
|
||||
elif user_code == device2_code:
|
||||
return mock_device2
|
||||
return None
|
||||
|
||||
mock_store.get_by_user_code.side_effect = get_by_user_code_side_effect
|
||||
mock_store.authorize_device_code.return_value = True
|
||||
mock_store.get_by_user_code = AsyncMock(
|
||||
side_effect=get_by_user_code_side_effect
|
||||
)
|
||||
mock_store.authorize_device_code = AsyncMock(return_value=True)
|
||||
|
||||
# Authenticate first device
|
||||
result1 = await device_verification_authenticated(
|
||||
@@ -305,8 +316,8 @@ class TestDeviceTokenRateLimiting:
|
||||
last_poll_time=None, # First poll
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
@@ -336,8 +347,8 @@ class TestDeviceTokenRateLimiting:
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
@@ -367,8 +378,8 @@ class TestDeviceTokenRateLimiting:
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
@@ -399,8 +410,8 @@ class TestDeviceTokenRateLimiting:
|
||||
last_poll_time=last_poll,
|
||||
current_interval=15, # Already increased from previous slow_down
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
@@ -430,8 +441,8 @@ class TestDeviceTokenRateLimiting:
|
||||
last_poll_time=last_poll,
|
||||
current_interval=58, # Near maximum of 60
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
@@ -457,8 +468,8 @@ class TestDeviceTokenRateLimiting:
|
||||
last_poll_time=last_poll,
|
||||
current_interval=5,
|
||||
)
|
||||
mock_store.get_by_device_code.return_value = mock_device
|
||||
mock_store.update_poll_time.return_value = True
|
||||
mock_store.get_by_device_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.update_poll_time = AsyncMock(return_value=True)
|
||||
|
||||
device_code = 'test_device_code'
|
||||
result = await device_token(device_code=device_code)
|
||||
@@ -487,8 +498,10 @@ class TestDeviceVerificationTransactionIntegrity:
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = False # Authorization fails
|
||||
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.authorize_device_code = AsyncMock(
|
||||
return_value=False
|
||||
) # Authorization fails
|
||||
|
||||
# Mock API key store with async create_api_key
|
||||
mock_api_key_store = MagicMock()
|
||||
@@ -519,9 +532,11 @@ class TestDeviceVerificationTransactionIntegrity:
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
mock_store.deny_device_code.return_value = True # Cleanup succeeds
|
||||
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.authorize_device_code = AsyncMock(
|
||||
return_value=True
|
||||
) # Authorization succeeds
|
||||
mock_store.deny_device_code = AsyncMock(return_value=True) # Cleanup succeeds
|
||||
|
||||
# Mock API key store to fail on creation (async)
|
||||
mock_api_key_store = MagicMock()
|
||||
@@ -559,10 +574,12 @@ class TestDeviceVerificationTransactionIntegrity:
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
mock_store.deny_device_code.side_effect = Exception(
|
||||
'Cleanup failed'
|
||||
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.authorize_device_code = AsyncMock(
|
||||
return_value=True
|
||||
) # Authorization succeeds
|
||||
mock_store.deny_device_code = AsyncMock(
|
||||
side_effect=Exception('Cleanup failed')
|
||||
) # Cleanup fails
|
||||
|
||||
# Mock API key store to fail on creation (async)
|
||||
@@ -595,8 +612,11 @@ class TestDeviceVerificationTransactionIntegrity:
|
||||
# Mock device code
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
mock_store.get_by_user_code = AsyncMock(return_value=mock_device)
|
||||
mock_store.authorize_device_code = AsyncMock(
|
||||
return_value=True
|
||||
) # Authorization succeeds
|
||||
mock_store.deny_device_code = AsyncMock()
|
||||
|
||||
# Mock API key store with async create_api_key
|
||||
mock_api_key_store = MagicMock()
|
||||
|
||||
@@ -11,41 +11,37 @@ import httpx
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from fastapi.testclient import TestClient
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
CannotModifySelfError,
|
||||
InsufficientPermissionError,
|
||||
InvalidRoleError,
|
||||
LastOwnerError,
|
||||
LiteLLMIntegrationError,
|
||||
MeResponse,
|
||||
OrgAppSettingsResponse,
|
||||
OrgAppSettingsUpdate,
|
||||
OrgAuthorizationError,
|
||||
OrgDatabaseError,
|
||||
OrgMemberNotFoundError,
|
||||
OrgMemberPage,
|
||||
OrgMemberResponse,
|
||||
OrgMemberUpdate,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
OrphanedUserError,
|
||||
RoleNotFoundError,
|
||||
)
|
||||
from server.routes.orgs import (
|
||||
get_me,
|
||||
get_org_members,
|
||||
org_router,
|
||||
remove_org_member,
|
||||
update_org_member,
|
||||
)
|
||||
from storage.org import Org
|
||||
|
||||
# Mock database before imports
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
CannotModifySelfError,
|
||||
InsufficientPermissionError,
|
||||
InvalidRoleError,
|
||||
LastOwnerError,
|
||||
LiteLLMIntegrationError,
|
||||
MeResponse,
|
||||
OrgAuthorizationError,
|
||||
OrgDatabaseError,
|
||||
OrgMemberNotFoundError,
|
||||
OrgMemberPage,
|
||||
OrgMemberResponse,
|
||||
OrgMemberUpdate,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
OrphanedUserError,
|
||||
RoleNotFoundError,
|
||||
)
|
||||
from server.routes.orgs import (
|
||||
get_me,
|
||||
get_org_members,
|
||||
org_router,
|
||||
remove_org_member,
|
||||
update_org_member,
|
||||
)
|
||||
from storage.org import Org
|
||||
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
# Test user ID constant (must be a valid UUID string)
|
||||
TEST_USER_ID = str(uuid.uuid4())
|
||||
@@ -3424,3 +3420,421 @@ async def test_switch_org_database_error(mock_app_with_get_user_id):
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert 'Failed to switch organization' in response.json()['detail']
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for App Settings Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_member_role():
|
||||
"""Create a mock member role for authorization tests."""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
return mock_role
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_app_settings_success(
|
||||
mock_app_with_get_user_id, mock_member_role
|
||||
):
|
||||
"""
|
||||
GIVEN: Authenticated user with MANAGE_APPLICATION_SETTINGS permission
|
||||
WHEN: GET /api/organizations/app is called
|
||||
THEN: App settings are returned with 200 status
|
||||
"""
|
||||
# Arrange
|
||||
mock_response = OrgAppSettingsResponse(
|
||||
enable_proactive_conversation_starters=True,
|
||||
enable_solvability_analysis=False,
|
||||
max_budget_per_task=10.0,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_member_role),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgAppSettingsService.get_org_app_settings',
|
||||
AsyncMock(return_value=mock_response),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act
|
||||
response = client.get('/api/organizations/app')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
response_data = response.json()
|
||||
assert response_data['enable_proactive_conversation_starters'] is True
|
||||
assert response_data['enable_solvability_analysis'] is False
|
||||
assert response_data['max_budget_per_task'] == 10.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_app_settings_with_null_values(
|
||||
mock_app_with_get_user_id, mock_member_role
|
||||
):
|
||||
"""
|
||||
GIVEN: Organization has null app settings values
|
||||
WHEN: GET /api/organizations/app is called
|
||||
THEN: Default values are returned where applicable
|
||||
"""
|
||||
# Arrange
|
||||
# OrgAppSettingsResponse.from_org() handles defaults, so we test the response model
|
||||
mock_response = OrgAppSettingsResponse(
|
||||
enable_proactive_conversation_starters=True, # Default when None in Org
|
||||
enable_solvability_analysis=None,
|
||||
max_budget_per_task=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_member_role),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgAppSettingsService.get_org_app_settings',
|
||||
AsyncMock(return_value=mock_response),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act
|
||||
response = client.get('/api/organizations/app')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
response_data = response.json()
|
||||
# enable_proactive_conversation_starters defaults to True when None
|
||||
assert response_data['enable_proactive_conversation_starters'] is True
|
||||
assert response_data['enable_solvability_analysis'] is None
|
||||
assert response_data['max_budget_per_task'] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_app_settings_not_found(
|
||||
mock_app_with_get_user_id, mock_member_role
|
||||
):
|
||||
"""
|
||||
GIVEN: User has no current organization
|
||||
WHEN: GET /api/organizations/app is called
|
||||
THEN: 404 Not Found error is returned
|
||||
"""
|
||||
# Arrange
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_member_role),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgAppSettingsService.get_org_app_settings',
|
||||
AsyncMock(side_effect=OrgNotFoundError('current')),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act
|
||||
response = client.get('/api/organizations/app')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert 'not found' in response.json()['detail'].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_app_settings_user_not_member(mock_app_with_get_user_id):
|
||||
"""
|
||||
GIVEN: User is not a member of any organization
|
||||
WHEN: GET /api/organizations/app is called
|
||||
THEN: 403 Forbidden error is returned
|
||||
"""
|
||||
# Arrange - user has no role (not a member)
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act
|
||||
response = client.get('/api/organizations/app')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert 'not a member' in response.json()['detail'].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_success(
|
||||
mock_app_with_get_user_id, mock_member_role
|
||||
):
|
||||
"""
|
||||
GIVEN: Valid update data and authenticated user
|
||||
WHEN: POST /api/organizations/app is called
|
||||
THEN: Updated app settings are returned with 200 status
|
||||
"""
|
||||
# Arrange
|
||||
mock_response = OrgAppSettingsResponse(
|
||||
enable_proactive_conversation_starters=False,
|
||||
enable_solvability_analysis=True,
|
||||
max_budget_per_task=25.0,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_member_role),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgAppSettingsService.update_org_app_settings',
|
||||
AsyncMock(return_value=mock_response),
|
||||
) as mock_update,
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act
|
||||
response = client.post(
|
||||
'/api/organizations/app',
|
||||
json={
|
||||
'enable_proactive_conversation_starters': False,
|
||||
'enable_solvability_analysis': True,
|
||||
'max_budget_per_task': 25.0,
|
||||
},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
response_data = response.json()
|
||||
assert response_data['enable_proactive_conversation_starters'] is False
|
||||
assert response_data['enable_solvability_analysis'] is True
|
||||
assert response_data['max_budget_per_task'] == 25.0
|
||||
mock_update.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_partial_update(
|
||||
mock_app_with_get_user_id, mock_member_role
|
||||
):
|
||||
"""
|
||||
GIVEN: Partial update data (only some fields)
|
||||
WHEN: POST /api/organizations/app is called
|
||||
THEN: Only specified fields are updated
|
||||
"""
|
||||
# Arrange
|
||||
mock_response = OrgAppSettingsResponse(
|
||||
enable_proactive_conversation_starters=False,
|
||||
enable_solvability_analysis=True,
|
||||
max_budget_per_task=10.0, # Unchanged
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_member_role),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgAppSettingsService.update_org_app_settings',
|
||||
AsyncMock(return_value=mock_response),
|
||||
) as mock_update,
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act - only updating one field
|
||||
response = client.post(
|
||||
'/api/organizations/app',
|
||||
json={'enable_proactive_conversation_starters': False},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
mock_update.assert_called_once()
|
||||
# Verify the update data only contains the specified field
|
||||
call_args = mock_update.call_args
|
||||
update_data = call_args[0][0] # First positional argument (update_data)
|
||||
assert isinstance(update_data, OrgAppSettingsUpdate)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_set_null(
|
||||
mock_app_with_get_user_id, mock_member_role
|
||||
):
|
||||
"""
|
||||
GIVEN: Request to set max_budget_per_task to null
|
||||
WHEN: POST /api/organizations/app is called
|
||||
THEN: The field is set to null successfully
|
||||
"""
|
||||
# Arrange
|
||||
mock_response = OrgAppSettingsResponse(
|
||||
enable_proactive_conversation_starters=True,
|
||||
enable_solvability_analysis=True,
|
||||
max_budget_per_task=None,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_member_role),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgAppSettingsService.update_org_app_settings',
|
||||
AsyncMock(return_value=mock_response),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act - explicitly setting max_budget_per_task to null
|
||||
response = client.post(
|
||||
'/api/organizations/app',
|
||||
json={'max_budget_per_task': None},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
response_data = response.json()
|
||||
assert response_data['max_budget_per_task'] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_invalid_max_budget(
|
||||
mock_app_with_get_user_id, mock_member_role
|
||||
):
|
||||
"""
|
||||
GIVEN: Invalid max_budget_per_task value (zero or negative)
|
||||
WHEN: POST /api/organizations/app is called
|
||||
THEN: 422 Validation error is returned
|
||||
"""
|
||||
# Arrange
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_member_role),
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act - negative value
|
||||
response = client.post(
|
||||
'/api/organizations/app',
|
||||
json={'max_budget_per_task': -5.0},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_zero_max_budget(
|
||||
mock_app_with_get_user_id, mock_member_role
|
||||
):
|
||||
"""
|
||||
GIVEN: max_budget_per_task is set to zero
|
||||
WHEN: POST /api/organizations/app is called
|
||||
THEN: 422 Validation error is returned (must be greater than 0)
|
||||
"""
|
||||
# Arrange
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_member_role),
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act - zero value
|
||||
response = client.post(
|
||||
'/api/organizations/app',
|
||||
json={'max_budget_per_task': 0},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_not_found(
|
||||
mock_app_with_get_user_id, mock_member_role
|
||||
):
|
||||
"""
|
||||
GIVEN: User has no current organization
|
||||
WHEN: POST /api/organizations/app is called
|
||||
THEN: 404 Not Found error is returned
|
||||
"""
|
||||
# Arrange
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_member_role),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgAppSettingsService.update_org_app_settings',
|
||||
AsyncMock(side_effect=OrgNotFoundError('current')),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act
|
||||
response = client.post(
|
||||
'/api/organizations/app',
|
||||
json={'enable_proactive_conversation_starters': False},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert 'not found' in response.json()['detail'].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_database_error(
|
||||
mock_app_with_get_user_id, mock_member_role
|
||||
):
|
||||
"""
|
||||
GIVEN: Database update fails
|
||||
WHEN: POST /api/organizations/app is called
|
||||
THEN: 500 Internal Server Error is returned
|
||||
"""
|
||||
# Arrange
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_member_role),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgAppSettingsService.update_org_app_settings',
|
||||
AsyncMock(side_effect=Exception('Database connection failed')),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act
|
||||
response = client.post(
|
||||
'/api/organizations/app',
|
||||
json={'enable_proactive_conversation_starters': False},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert 'unexpected error' in response.json()['detail'].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_user_not_member(mock_app_with_get_user_id):
|
||||
"""
|
||||
GIVEN: User is not a member of any organization
|
||||
WHEN: POST /api/organizations/app is called
|
||||
THEN: 403 Forbidden error is returned
|
||||
"""
|
||||
# Arrange - user has no role (not a member)
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act
|
||||
response = client.post(
|
||||
'/api/organizations/app',
|
||||
json={'enable_proactive_conversation_starters': False},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert 'not a member' in response.json()['detail'].lower()
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
Unit tests for OrgAppSettingsService.
|
||||
|
||||
Tests the service layer for organization app settings operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from server.routes.org_models import (
|
||||
OrgAppSettingsResponse,
|
||||
OrgAppSettingsUpdate,
|
||||
OrgNotFoundError,
|
||||
)
|
||||
from server.services.org_app_settings_service import OrgAppSettingsService
|
||||
from storage.org import Org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_id():
|
||||
"""Create a test user ID."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_org():
|
||||
"""Create a mock organization with app settings."""
|
||||
org = MagicMock(spec=Org)
|
||||
org.id = uuid.uuid4()
|
||||
org.enable_proactive_conversation_starters = True
|
||||
org.enable_solvability_analysis = False
|
||||
org.max_budget_per_task = 25.0
|
||||
return org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_store():
|
||||
"""Create a mock OrgAppSettingsStore."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_context(user_id):
|
||||
"""Create a mock UserContext that returns the user_id."""
|
||||
context = MagicMock()
|
||||
context.get_user_id = AsyncMock(return_value=user_id)
|
||||
return context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_app_settings_success(
|
||||
user_id, mock_org, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user's current organization exists
|
||||
WHEN: get_org_app_settings is called
|
||||
THEN: OrgAppSettingsResponse is returned with correct data
|
||||
"""
|
||||
# Arrange
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
|
||||
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.get_org_app_settings()
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OrgAppSettingsResponse)
|
||||
assert result.enable_proactive_conversation_starters is True
|
||||
assert result.enable_solvability_analysis is False
|
||||
assert result.max_budget_per_task == 25.0
|
||||
mock_store.get_current_org_by_user_id.assert_called_once_with(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_app_settings_org_not_found(
|
||||
user_id, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user has no current organization
|
||||
WHEN: get_org_app_settings is called
|
||||
THEN: OrgNotFoundError is raised
|
||||
"""
|
||||
# Arrange
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=None)
|
||||
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNotFoundError) as exc_info:
|
||||
await service.get_org_app_settings()
|
||||
|
||||
assert 'current' in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_success(
|
||||
user_id, mock_org, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user's current organization exists
|
||||
WHEN: update_org_app_settings is called with new values
|
||||
THEN: OrgAppSettingsResponse is returned with updated data
|
||||
"""
|
||||
# Arrange
|
||||
mock_org.enable_proactive_conversation_starters = False
|
||||
mock_org.max_budget_per_task = 50.0
|
||||
|
||||
update_data = OrgAppSettingsUpdate(
|
||||
enable_proactive_conversation_starters=False,
|
||||
max_budget_per_task=50.0,
|
||||
)
|
||||
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
|
||||
mock_store.update_org_app_settings = AsyncMock(return_value=mock_org)
|
||||
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.update_org_app_settings(update_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OrgAppSettingsResponse)
|
||||
assert result.enable_proactive_conversation_starters is False
|
||||
assert result.max_budget_per_task == 50.0
|
||||
mock_store.update_org_app_settings.assert_called_once_with(
|
||||
org_id=mock_org.id, update_data=update_data
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_no_changes(
|
||||
user_id, mock_org, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user's current organization exists
|
||||
WHEN: update_org_app_settings is called with no fields
|
||||
THEN: Current settings are returned without calling update
|
||||
"""
|
||||
# Arrange
|
||||
update_data = OrgAppSettingsUpdate() # No fields set
|
||||
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
|
||||
mock_store.update_org_app_settings = AsyncMock()
|
||||
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.update_org_app_settings(update_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OrgAppSettingsResponse)
|
||||
mock_store.get_current_org_by_user_id.assert_called_once_with(user_id)
|
||||
mock_store.update_org_app_settings.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_org_not_found(
|
||||
user_id, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user has no current organization
|
||||
WHEN: update_org_app_settings is called
|
||||
THEN: OrgNotFoundError is raised
|
||||
"""
|
||||
# Arrange
|
||||
update_data = OrgAppSettingsUpdate(enable_proactive_conversation_starters=False)
|
||||
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=None)
|
||||
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNotFoundError) as exc_info:
|
||||
await service.update_org_app_settings(update_data)
|
||||
|
||||
assert 'current' in str(exc_info.value)
|
||||
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
Unit tests for OrgLLMSettingsService.
|
||||
|
||||
Tests the service layer for organization LLM settings operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from server.routes.org_models import (
|
||||
OrgLLMSettingsResponse,
|
||||
OrgLLMSettingsUpdate,
|
||||
OrgNotFoundError,
|
||||
)
|
||||
from server.services.org_llm_settings_service import OrgLLMSettingsService
|
||||
from storage.org import Org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_id():
|
||||
"""Create a test user ID."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def org_id():
|
||||
"""Create a test org ID."""
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_org(org_id):
|
||||
"""Create a mock organization with LLM settings."""
|
||||
org = MagicMock(spec=Org)
|
||||
org.id = org_id
|
||||
org.default_llm_model = 'claude-3'
|
||||
org.default_llm_base_url = 'https://api.anthropic.com'
|
||||
org.search_api_key = None
|
||||
org.agent = 'CodeActAgent'
|
||||
org.confirmation_mode = True
|
||||
org.security_analyzer = None
|
||||
org.enable_default_condenser = True
|
||||
org.condenser_max_size = None
|
||||
org.default_max_iterations = 50
|
||||
return org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_store():
|
||||
"""Create a mock OrgLLMSettingsStore."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_context(user_id):
|
||||
"""Create a mock UserContext that returns the user_id."""
|
||||
context = MagicMock()
|
||||
context.get_user_id = AsyncMock(return_value=user_id)
|
||||
return context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_llm_settings_success(
|
||||
user_id, mock_org, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user with a current organization
|
||||
WHEN: get_org_llm_settings is called
|
||||
THEN: OrgLLMSettingsResponse is returned with correct data
|
||||
"""
|
||||
# Arrange
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
|
||||
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.get_org_llm_settings()
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OrgLLMSettingsResponse)
|
||||
assert result.default_llm_model == 'claude-3'
|
||||
assert result.agent == 'CodeActAgent'
|
||||
mock_store.get_current_org_by_user_id.assert_called_once_with(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_llm_settings_user_not_authenticated(mock_store):
|
||||
"""
|
||||
GIVEN: A user is not authenticated
|
||||
WHEN: get_org_llm_settings is called
|
||||
THEN: ValueError is raised
|
||||
"""
|
||||
# Arrange
|
||||
mock_user_context = MagicMock()
|
||||
mock_user_context.get_user_id = AsyncMock(return_value=None)
|
||||
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await service.get_org_llm_settings()
|
||||
|
||||
assert 'not authenticated' in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_llm_settings_org_not_found(
|
||||
user_id, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user has no current organization
|
||||
WHEN: get_org_llm_settings is called
|
||||
THEN: OrgNotFoundError is raised
|
||||
"""
|
||||
# Arrange
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=None)
|
||||
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNotFoundError) as exc_info:
|
||||
await service.get_org_llm_settings()
|
||||
|
||||
assert 'No current organization' in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_llm_settings_success(
|
||||
user_id, mock_org, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user with a current organization
|
||||
WHEN: update_org_llm_settings is called with new values
|
||||
THEN: OrgLLMSettingsResponse is returned with updated data
|
||||
"""
|
||||
# Arrange
|
||||
updated_org = MagicMock(spec=Org)
|
||||
updated_org.id = mock_org.id
|
||||
updated_org.default_llm_model = 'new-model'
|
||||
updated_org.default_llm_base_url = None
|
||||
updated_org.search_api_key = None
|
||||
updated_org.agent = 'CodeActAgent'
|
||||
updated_org.confirmation_mode = False
|
||||
updated_org.security_analyzer = None
|
||||
updated_org.enable_default_condenser = True
|
||||
updated_org.condenser_max_size = None
|
||||
updated_org.default_max_iterations = 100
|
||||
|
||||
update_data = OrgLLMSettingsUpdate(
|
||||
default_llm_model='new-model',
|
||||
confirmation_mode=False,
|
||||
default_max_iterations=100,
|
||||
)
|
||||
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
|
||||
mock_store.update_org_llm_settings = AsyncMock(return_value=updated_org)
|
||||
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.update_org_llm_settings(update_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OrgLLMSettingsResponse)
|
||||
assert result.default_llm_model == 'new-model'
|
||||
assert result.confirmation_mode is False
|
||||
assert result.default_max_iterations == 100
|
||||
mock_store.update_org_llm_settings.assert_called_once_with(
|
||||
org_id=mock_org.id,
|
||||
update_data=update_data,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_llm_settings_no_changes(
|
||||
user_id, mock_org, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user with a current organization
|
||||
WHEN: update_org_llm_settings is called with no fields
|
||||
THEN: Current settings are returned without calling update
|
||||
"""
|
||||
# Arrange
|
||||
update_data = OrgLLMSettingsUpdate() # No fields set
|
||||
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
|
||||
mock_store.update_org_llm_settings = AsyncMock()
|
||||
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.update_org_llm_settings(update_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OrgLLMSettingsResponse)
|
||||
assert result.default_llm_model == 'claude-3'
|
||||
mock_store.update_org_llm_settings.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_llm_settings_org_not_found(
|
||||
user_id, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user has no current organization
|
||||
WHEN: update_org_llm_settings is called
|
||||
THEN: OrgNotFoundError is raised
|
||||
"""
|
||||
# Arrange
|
||||
update_data = OrgLLMSettingsUpdate(default_llm_model='new-model')
|
||||
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=None)
|
||||
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNotFoundError) as exc_info:
|
||||
await service.update_org_llm_settings(update_data)
|
||||
|
||||
assert 'No current organization' in str(exc_info.value)
|
||||
@@ -399,3 +399,135 @@ class TestUpdateActiveWorkingSeconds:
|
||||
assert conversation_work.seconds == 23.0
|
||||
assert conversation_work.conversation_id == conversation_id
|
||||
assert conversation_work.user_id == user_id
|
||||
|
||||
|
||||
class TestInvokeConversationCallbacks:
|
||||
"""Tests for invoke_conversation_callbacks function.
|
||||
|
||||
This function uses async database sessions (a_session_maker) to query
|
||||
and invoke callbacks for a conversation.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_observation(self):
|
||||
"""Create a mock AgentStateChangedObservation."""
|
||||
|
||||
observation = Mock(spec=AgentStateChangedObservation)
|
||||
observation.agent_state = AgentState.FINISHED
|
||||
return observation
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_async_session(self):
|
||||
"""Factory to create properly mocked async session context manager."""
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
def _create(callbacks_list):
|
||||
mock_session = Mock()
|
||||
mock_result = Mock()
|
||||
mock_result.scalars.return_value.all.return_value = callbacks_list
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock(return_value=None)
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_context_manager():
|
||||
yield mock_session
|
||||
|
||||
return mock_context_manager, mock_session
|
||||
|
||||
return _create
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_callbacks_with_active_callbacks(
|
||||
self, mock_observation, create_mock_async_session
|
||||
):
|
||||
"""Test that active callbacks are invoked successfully."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
# Arrange
|
||||
conversation_id = 'test_conversation_callbacks'
|
||||
mock_processor = AsyncMock(return_value=None)
|
||||
|
||||
# Create a mock callback
|
||||
mock_callback = Mock()
|
||||
mock_callback.id = 1
|
||||
mock_callback.processor_type = 'test_processor'
|
||||
mock_callback.get_processor.return_value = mock_processor
|
||||
|
||||
mock_context_manager, mock_session = create_mock_async_session([mock_callback])
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.utils.conversation_callback_utils.a_session_maker',
|
||||
mock_context_manager,
|
||||
):
|
||||
from server.utils.conversation_callback_utils import (
|
||||
invoke_conversation_callbacks,
|
||||
)
|
||||
|
||||
await invoke_conversation_callbacks(conversation_id, mock_observation)
|
||||
|
||||
# Assert
|
||||
mock_callback.get_processor.assert_called_once()
|
||||
mock_processor.assert_called_once_with(mock_callback, mock_observation)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_callbacks_with_no_active_callbacks(
|
||||
self, mock_observation, create_mock_async_session
|
||||
):
|
||||
"""Test behavior when no active callbacks exist."""
|
||||
# Arrange
|
||||
conversation_id = 'test_no_callbacks'
|
||||
|
||||
mock_context_manager, mock_session = create_mock_async_session([])
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.utils.conversation_callback_utils.a_session_maker',
|
||||
mock_context_manager,
|
||||
):
|
||||
from server.utils.conversation_callback_utils import (
|
||||
invoke_conversation_callbacks,
|
||||
)
|
||||
|
||||
await invoke_conversation_callbacks(conversation_id, mock_observation)
|
||||
|
||||
# Assert - should complete without errors
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_callbacks_handles_processor_exception(
|
||||
self, mock_observation, create_mock_async_session
|
||||
):
|
||||
"""Test that processor exceptions are caught and callback status is updated."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
# Arrange
|
||||
conversation_id = 'test_callback_error'
|
||||
mock_processor = AsyncMock(side_effect=Exception('Processor error'))
|
||||
|
||||
mock_callback = Mock()
|
||||
mock_callback.id = 1
|
||||
mock_callback.processor_type = 'failing_processor'
|
||||
mock_callback.get_processor.return_value = mock_processor
|
||||
mock_callback.status = 'active'
|
||||
|
||||
mock_context_manager, mock_session = create_mock_async_session([mock_callback])
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'server.utils.conversation_callback_utils.a_session_maker',
|
||||
mock_context_manager,
|
||||
), patch('server.utils.conversation_callback_utils.logger') as mock_logger:
|
||||
from server.utils.conversation_callback_utils import (
|
||||
invoke_conversation_callbacks,
|
||||
)
|
||||
from storage.conversation_callback import CallbackStatus
|
||||
|
||||
await invoke_conversation_callbacks(conversation_id, mock_observation)
|
||||
|
||||
# Assert - callback status should be set to ERROR
|
||||
assert mock_callback.status == CallbackStatus.ERROR
|
||||
mock_logger.error.assert_called_once()
|
||||
error_call = mock_logger.error.call_args
|
||||
assert error_call[0][0] == 'callback_invocation_failed'
|
||||
|
||||
@@ -1,127 +1,127 @@
|
||||
"""Unit tests for AuthTokenStore."""
|
||||
"""Unit tests for AuthTokenStore using SQLite in-memory database."""
|
||||
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from server.auth.auth_error import TokenRefreshError
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.auth_token_store import (
|
||||
ACCESS_TOKEN_EXPIRY_BUFFER,
|
||||
LOCK_TIMEOUT_SECONDS,
|
||||
AuthTokenStore,
|
||||
)
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.base import Base
|
||||
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
|
||||
|
||||
def create_mock_session():
|
||||
"""Create a mock async session with properly configured context managers."""
|
||||
session = AsyncMock()
|
||||
|
||||
# Create async context manager for begin()
|
||||
@asynccontextmanager
|
||||
async def begin_context():
|
||||
yield
|
||||
|
||||
session.begin = begin_context
|
||||
return session
|
||||
|
||||
|
||||
def create_mock_session_maker(mock_session):
|
||||
"""Create a mock async session maker."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def session_context():
|
||||
yield mock_session
|
||||
|
||||
# Return a callable that returns the context manager
|
||||
return lambda: session_context()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Create mock async session."""
|
||||
return create_mock_session()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(mock_session):
|
||||
"""Create mock async session maker."""
|
||||
return create_mock_session_maker(mock_session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_token_store(mock_session_maker):
|
||||
"""Create AuthTokenStore instance with mocked session maker."""
|
||||
return AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker bound to the async engine."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
bind=async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
# Create all tables
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
return async_session_maker
|
||||
|
||||
|
||||
class TestIsTokenExpired:
|
||||
"""Tests for _is_token_expired method."""
|
||||
|
||||
def test_both_tokens_valid(self, auth_token_store):
|
||||
def test_both_tokens_valid(self):
|
||||
"""Test when both tokens are valid (not expired)."""
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
current_time = int(time.time())
|
||||
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
refresh_expires = current_time + 1000
|
||||
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expired, refresh_expired = store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is False
|
||||
assert refresh_expired is False
|
||||
|
||||
def test_access_token_expired(self, auth_token_store):
|
||||
def test_access_token_expired(self):
|
||||
"""Test when access token is expired but within buffer."""
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
current_time = int(time.time())
|
||||
# Access token expires within buffer period
|
||||
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER - 100
|
||||
refresh_expires = current_time + 10000
|
||||
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expired, refresh_expired = store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is True
|
||||
assert refresh_expired is False
|
||||
|
||||
def test_refresh_token_expired(self, auth_token_store):
|
||||
def test_refresh_token_expired(self):
|
||||
"""Test when refresh token is expired."""
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
current_time = int(time.time())
|
||||
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
refresh_expires = current_time - 100 # Already expired
|
||||
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expired, refresh_expired = store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is False
|
||||
assert refresh_expired is True
|
||||
|
||||
def test_both_tokens_expired(self, auth_token_store):
|
||||
def test_both_tokens_expired(self):
|
||||
"""Test when both tokens are expired."""
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
current_time = int(time.time())
|
||||
access_expires = current_time - 100
|
||||
refresh_expires = current_time - 100
|
||||
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expired, refresh_expired = store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is True
|
||||
assert refresh_expired is True
|
||||
|
||||
def test_zero_expiration_treated_as_never_expires(self, auth_token_store):
|
||||
def test_zero_expiration_treated_as_never_expires(self):
|
||||
"""Test that 0 expiration time is treated as never expires."""
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(0, 0)
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
access_expired, refresh_expired = store._is_token_expired(0, 0)
|
||||
|
||||
assert access_expired is False
|
||||
assert refresh_expired is False
|
||||
@@ -131,427 +131,188 @@ class TestLoadTokensFastPath:
|
||||
"""Tests for load_tokens fast path (no lock needed)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_token_not_found(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
async def test_fast_path_token_not_found(self, async_session_maker):
|
||||
"""Test fast path returns None when no token record exists."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
result = await auth_token_store.load_tokens()
|
||||
result = await store.load_tokens()
|
||||
|
||||
assert result is None
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_valid_token_no_refresh_needed(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
async def test_fast_path_valid_token_no_refresh_needed(self, async_session_maker):
|
||||
"""Test fast path returns tokens when they are still valid."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'valid-access-token'
|
||||
mock_token.refresh_token = 'valid-refresh-token'
|
||||
mock_token.access_token_expires_at = (
|
||||
current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
)
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
# First, store a valid token in the database
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
result = await auth_token_store.load_tokens()
|
||||
await store.store_tokens(
|
||||
access_token='valid-access-token',
|
||||
refresh_token='valid-refresh-token',
|
||||
access_token_expires_at=current_time
|
||||
+ ACCESS_TOKEN_EXPIRY_BUFFER
|
||||
+ 1000,
|
||||
refresh_token_expires_at=current_time + 10000,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'valid-access-token'
|
||||
assert result['refresh_token'] == 'valid-refresh-token'
|
||||
# Now load tokens - should return valid tokens without refresh
|
||||
result = await store.load_tokens()
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'valid-access-token'
|
||||
assert result['refresh_token'] == 'valid-refresh-token'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_no_refresh_callback_provided(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
async def test_fast_path_no_refresh_callback_provided(self, async_session_maker):
|
||||
"""Test fast path returns existing tokens when no refresh callback is provided."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'expired-access-token'
|
||||
mock_token.refresh_token = 'valid-refresh-token'
|
||||
# Expired access token
|
||||
mock_token.access_token_expires_at = current_time - 100
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
# Store expired access token
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
result = await auth_token_store.load_tokens(check_expiration_and_refresh=None)
|
||||
await store.store_tokens(
|
||||
access_token='expired-access-token',
|
||||
refresh_token='valid-refresh-token',
|
||||
access_token_expires_at=current_time - 100, # Expired
|
||||
refresh_token_expires_at=current_time + 10000,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'expired-access-token'
|
||||
# Load without refresh callback - should still return tokens
|
||||
result = await store.load_tokens(check_expiration_and_refresh=None)
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'expired-access-token'
|
||||
|
||||
|
||||
class TestLoadTokensSlowPath:
|
||||
"""Tests for load_tokens slow path (lock required for refresh)."""
|
||||
"""Tests for load_tokens slow path (lock required for refresh).
|
||||
|
||||
Note: These tests require PostgreSQL's lock_timeout feature which is not
|
||||
available in SQLite. The slow path tests are skipped when using SQLite.
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason='SQLite does not support PostgreSQL lock_timeout syntax')
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_path_successful_refresh(self):
|
||||
async def test_slow_path_successful_refresh(self, async_session_maker):
|
||||
"""Test slow path successfully refreshes expired tokens."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
pass
|
||||
|
||||
# First call (fast path) - returns expired token
|
||||
# Second call (slow path) - returns same token for update
|
||||
expired_token = MagicMock()
|
||||
expired_token.id = 1
|
||||
expired_token.access_token = 'expired-access-token'
|
||||
expired_token.refresh_token = 'valid-refresh-token'
|
||||
expired_token.access_token_expires_at = current_time - 100 # Expired
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(
|
||||
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
|
||||
) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-access-token',
|
||||
'refresh_token': 'new-refresh-token',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'new-access-token'
|
||||
assert result['refresh_token'] == 'new-refresh-token'
|
||||
@pytest.mark.skip(reason='SQLite does not support PostgreSQL lock_timeout syntax')
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_callback_returns_none(self, async_session_maker):
|
||||
"""Test behavior when refresh callback returns None (no refresh performed)."""
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_path_double_check_avoids_refresh(self):
|
||||
"""Test double-check locking: token was refreshed by another request."""
|
||||
async def test_slow_path_double_check_avoids_refresh(self, async_session_maker):
|
||||
"""Test double-check pattern avoids unnecessary refresh."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
# Simulate scenario:
|
||||
# 1. Fast path sees expired token
|
||||
# 2. While waiting for lock, another request refreshes
|
||||
# 3. Slow path sees fresh token, skips refresh
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def create_token():
|
||||
call_count[0] += 1
|
||||
token = MagicMock()
|
||||
token.id = 1
|
||||
token.access_token = 'fresh-access-token'
|
||||
token.refresh_token = 'fresh-refresh-token'
|
||||
if call_count[0] == 1:
|
||||
# First call (fast path) - expired
|
||||
token.access_token_expires_at = current_time - 100
|
||||
else:
|
||||
# Second call (slow path) - already refreshed
|
||||
token.access_token_expires_at = (
|
||||
current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
)
|
||||
token.refresh_token_expires_at = current_time + 86400
|
||||
return token
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.side_effect = (
|
||||
lambda: create_token()
|
||||
)
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
refresh_called = [False]
|
||||
|
||||
async def mock_refresh(
|
||||
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
|
||||
) -> Dict[str, str | int]:
|
||||
refresh_called[0] = True
|
||||
return {
|
||||
'access_token': 'should-not-be-used',
|
||||
'refresh_token': 'should-not-be-used',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
# The refresh callback should not be called because double-check
|
||||
# found the token was already refreshed
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'fresh-access-token'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_path_token_not_found_after_lock(self):
|
||||
"""Test slow path returns None if token record disappears after lock."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
# First call (fast path) - token exists but expired
|
||||
# Second call (slow path with lock) - token no longer exists
|
||||
call_count = [0]
|
||||
|
||||
def get_token():
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
token = MagicMock()
|
||||
token.access_token_expires_at = current_time - 100 # Expired
|
||||
token.refresh_token_expires_at = current_time + 10000
|
||||
return token
|
||||
return None
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.side_effect = get_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(*args) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-token',
|
||||
'refresh_token': 'new-refresh',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestLoadTokensLockTimeout:
|
||||
"""Tests for lock timeout handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_timeout_raises_token_refresh_error(self):
|
||||
"""Test that lock timeout raises TokenRefreshError."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
# First call (fast path) - returns expired token
|
||||
expired_token = MagicMock()
|
||||
expired_token.access_token_expires_at = current_time - 100
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
|
||||
# First execute for fast path succeeds
|
||||
# Second execute (for slow path) raises OperationalError
|
||||
call_count = [0]
|
||||
|
||||
async def execute_side_effect(*args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] <= 1:
|
||||
return mock_result
|
||||
# Simulate lock timeout
|
||||
raise OperationalError(
|
||||
'canceling statement due to lock timeout', None, None
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
mock_session.execute = execute_side_effect
|
||||
# Store a token that will be valid when second check happens
|
||||
await store.store_tokens(
|
||||
access_token='original-access-token',
|
||||
refresh_token='valid-refresh-token',
|
||||
access_token_expires_at=current_time
|
||||
+ ACCESS_TOKEN_EXPIRY_BUFFER
|
||||
+ 1000,
|
||||
refresh_token_expires_at=current_time + 10000,
|
||||
)
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
# Load with refresh callback - should NOT refresh since token is valid
|
||||
result = await store.load_tokens()
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(*args) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-token',
|
||||
'refresh_token': 'new-refresh',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
with pytest.raises(TokenRefreshError) as exc_info:
|
||||
await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
assert 'lock timeout' in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_timeout_preserves_original_exception(self):
|
||||
"""Test that TokenRefreshError preserves the original OperationalError."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
expired_token = MagicMock()
|
||||
expired_token.access_token_expires_at = current_time - 100
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
|
||||
original_error = OperationalError(
|
||||
'canceling statement due to lock timeout', None, None
|
||||
)
|
||||
|
||||
call_count = [0]
|
||||
|
||||
async def execute_side_effect(*args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] <= 1:
|
||||
return mock_result
|
||||
raise original_error
|
||||
|
||||
mock_session.execute = execute_side_effect
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(*args) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-token',
|
||||
'refresh_token': 'new-refresh',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
with pytest.raises(TokenRefreshError) as exc_info:
|
||||
await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
# Verify the original exception is chained
|
||||
assert exc_info.value.__cause__ is original_error
|
||||
|
||||
|
||||
class TestLoadTokensRefreshCallbackBehavior:
|
||||
"""Tests for refresh callback return values."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_callback_returns_none(self):
|
||||
"""Test behavior when refresh callback returns None (no refresh performed)."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
expired_token = MagicMock()
|
||||
expired_token.id = 1
|
||||
expired_token.access_token = 'old-access-token'
|
||||
expired_token.refresh_token = 'old-refresh-token'
|
||||
expired_token.access_token_expires_at = current_time - 100 # Expired
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh_returns_none(
|
||||
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
|
||||
) -> Dict[str, str | int] | None:
|
||||
return None
|
||||
|
||||
result = await auth_store.load_tokens(
|
||||
check_expiration_and_refresh=mock_refresh_returns_none
|
||||
)
|
||||
|
||||
# Should return the old tokens when refresh returns None
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'old-access-token'
|
||||
assert result['refresh_token'] == 'old-refresh-token'
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'original-access-token'
|
||||
|
||||
|
||||
class TestStoreTokens:
|
||||
"""Tests for store_tokens method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_tokens_creates_new_record(self):
|
||||
async def test_store_tokens_creates_new_record(self, async_session_maker):
|
||||
"""Test storing tokens when no existing record."""
|
||||
mock_session = create_mock_session()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
await store.store_tokens(
|
||||
access_token='new-access-token',
|
||||
refresh_token='new-refresh-token',
|
||||
access_token_expires_at=1234567890,
|
||||
refresh_token_expires_at=1234657890,
|
||||
)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
await auth_store.store_tokens(
|
||||
access_token='new-access-token',
|
||||
refresh_token='new-refresh-token',
|
||||
access_token_expires_at=1234567890,
|
||||
refresh_token_expires_at=1234657890,
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
# Verify the token was stored
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(AuthTokens).where(
|
||||
AuthTokens.keycloak_user_id == 'test-user-123',
|
||||
AuthTokens.identity_provider == ProviderType.GITHUB.value,
|
||||
)
|
||||
)
|
||||
token_record = result.scalars().first()
|
||||
assert token_record is not None
|
||||
assert token_record.access_token == 'new-access-token'
|
||||
assert token_record.refresh_token == 'new-refresh-token'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_tokens_updates_existing_record(self):
|
||||
async def test_store_tokens_updates_existing_record(self, async_session_maker):
|
||||
"""Test storing tokens updates existing record."""
|
||||
mock_session = create_mock_session()
|
||||
existing_token = MagicMock()
|
||||
existing_token.access_token = 'old-access'
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = existing_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
# First, create a token record
|
||||
await store.store_tokens(
|
||||
access_token='old-access-token',
|
||||
refresh_token='old-refresh-token',
|
||||
access_token_expires_at=1234567890,
|
||||
refresh_token_expires_at=1234657890,
|
||||
)
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
# Now update it
|
||||
await store.store_tokens(
|
||||
access_token='new-access-token',
|
||||
refresh_token='new-refresh-token',
|
||||
access_token_expires_at=1234567891,
|
||||
refresh_token_expires_at=1234657891,
|
||||
)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
await auth_store.store_tokens(
|
||||
access_token='new-access-token',
|
||||
refresh_token='new-refresh-token',
|
||||
access_token_expires_at=1234567890,
|
||||
refresh_token_expires_at=1234657890,
|
||||
)
|
||||
|
||||
assert existing_token.access_token == 'new-access-token'
|
||||
assert existing_token.refresh_token == 'new-refresh-token'
|
||||
# Verify the token was updated
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(AuthTokens).where(
|
||||
AuthTokens.keycloak_user_id == 'test-user-123',
|
||||
AuthTokens.identity_provider == ProviderType.GITHUB.value,
|
||||
)
|
||||
)
|
||||
token_record = result.scalars().first()
|
||||
assert token_record is not None
|
||||
assert token_record.access_token == 'new-access-token'
|
||||
assert token_record.refresh_token == 'new-refresh-token'
|
||||
|
||||
|
||||
class TestIsAccessTokenValid:
|
||||
@@ -559,80 +320,93 @@ class TestIsAccessTokenValid:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_access_token_valid_returns_false_when_no_tokens(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
self, async_session_maker
|
||||
):
|
||||
"""Test returns False when no tokens found."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
result = await auth_token_store.is_access_token_valid()
|
||||
result = await store.is_access_token_valid()
|
||||
|
||||
assert result is False
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_access_token_valid_returns_true_for_valid_token(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
self, async_session_maker
|
||||
):
|
||||
"""Test returns True when token is valid."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'valid-access'
|
||||
mock_token.refresh_token = 'valid-refresh'
|
||||
mock_token.access_token_expires_at = current_time + 1000
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
result = await auth_token_store.is_access_token_valid()
|
||||
await store.store_tokens(
|
||||
access_token='valid-access',
|
||||
refresh_token='valid-refresh',
|
||||
access_token_expires_at=current_time + 1000,
|
||||
refresh_token_expires_at=current_time + 10000,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
result = await store.is_access_token_valid()
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_access_token_valid_returns_false_for_expired_token(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
self, async_session_maker
|
||||
):
|
||||
"""Test returns False when token is expired."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'expired-access'
|
||||
mock_token.refresh_token = 'valid-refresh'
|
||||
mock_token.access_token_expires_at = current_time - 100 # Expired
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
result = await auth_token_store.is_access_token_valid()
|
||||
await store.store_tokens(
|
||||
access_token='expired-access',
|
||||
refresh_token='valid-refresh',
|
||||
access_token_expires_at=current_time - 100, # Expired
|
||||
refresh_token_expires_at=current_time + 10000,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
result = await store.is_access_token_valid()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetInstance:
|
||||
"""Tests for get_instance class method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instance_creates_auth_token_store(self):
|
||||
async def test_get_instance_creates_auth_token_store(self, async_session_maker):
|
||||
"""Test get_instance creates an AuthTokenStore with correct params."""
|
||||
with patch('storage.auth_token_store.a_session_maker') as mock_a_session_maker:
|
||||
with patch('storage.auth_token_store.a_session_maker', async_session_maker):
|
||||
store = await AuthTokenStore.get_instance(
|
||||
keycloak_user_id='user-123', idp=ProviderType.GITHUB
|
||||
)
|
||||
|
||||
assert store.keycloak_user_id == 'user-123'
|
||||
assert store.idp == ProviderType.GITHUB
|
||||
assert store.a_session_maker is mock_a_session_maker
|
||||
|
||||
|
||||
class TestIdentityProviderValue:
|
||||
"""Tests for identity_provider_value property."""
|
||||
|
||||
def test_identity_provider_value_returns_idp_value(self, auth_token_store):
|
||||
def test_identity_provider_value_returns_idp_value(self):
|
||||
"""Test that identity_provider_value returns the enum value."""
|
||||
assert auth_token_store.identity_provider_value == ProviderType.GITHUB.value
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=ProviderType.GITHUB,
|
||||
)
|
||||
assert store.identity_provider_value == ProviderType.GITHUB.value
|
||||
|
||||
def test_identity_provider_value_for_different_providers(self):
|
||||
"""Test identity_provider_value for different providers."""
|
||||
@@ -644,7 +418,6 @@ class TestIdentityProviderValue:
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=provider,
|
||||
a_session_maker=MagicMock(),
|
||||
)
|
||||
assert store.identity_provider_value == provider.value
|
||||
|
||||
|
||||
@@ -1,33 +1,17 @@
|
||||
"""Unit tests for DeviceCodeStore."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy import select
|
||||
from storage.device_code import DeviceCode
|
||||
from storage.device_code_store import DeviceCodeStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Mock database session."""
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(mock_session):
|
||||
"""Mock session maker."""
|
||||
session_maker = MagicMock()
|
||||
session_maker.return_value.__enter__.return_value = mock_session
|
||||
session_maker.return_value.__exit__.return_value = None
|
||||
return session_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def device_code_store(mock_session_maker):
|
||||
def device_code_store():
|
||||
"""Create DeviceCodeStore instance."""
|
||||
return DeviceCodeStore(mock_session_maker)
|
||||
return DeviceCodeStore()
|
||||
|
||||
|
||||
class TestDeviceCodeStore:
|
||||
@@ -49,145 +33,257 @@ class TestDeviceCodeStore:
|
||||
assert len(code) == 128
|
||||
assert code.isalnum()
|
||||
|
||||
def test_create_device_code_success(self, device_code_store, mock_session):
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_device_code_success(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test successful device code creation."""
|
||||
# Mock successful creation (no IntegrityError)
|
||||
mock_device_code = MagicMock(spec=DeviceCode)
|
||||
mock_device_code.device_code = 'test-device-code-123'
|
||||
mock_device_code.user_code = 'TESTCODE'
|
||||
|
||||
# Mock the session to return our mock device code after refresh
|
||||
def mock_refresh(obj):
|
||||
obj.device_code = mock_device_code.device_code
|
||||
obj.user_code = mock_device_code.user_code
|
||||
|
||||
mock_session.refresh.side_effect = mock_refresh
|
||||
|
||||
result = device_code_store.create_device_code(expires_in=600)
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
result = await device_code_store.create_device_code(expires_in=600)
|
||||
|
||||
assert isinstance(result, DeviceCode)
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
mock_session.expunge.assert_called_once()
|
||||
assert len(result.device_code) == 128
|
||||
assert len(result.user_code) == 8
|
||||
|
||||
def test_create_device_code_with_retries(
|
||||
self, device_code_store, mock_session_maker
|
||||
# Verify the DeviceCode was created in the database
|
||||
async with async_session_maker() as session:
|
||||
result_db = await session.execute(
|
||||
select(DeviceCode).filter(DeviceCode.device_code == result.device_code)
|
||||
)
|
||||
device_code = result_db.scalars().first()
|
||||
assert device_code is not None
|
||||
assert device_code.user_code == result.user_code
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_device_code_with_retries(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test device code creation with constraint violation retries."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
# First create a device code to cause a collision
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
first_code = await device_code_store.create_device_code(expires_in=600)
|
||||
|
||||
# First attempt fails with IntegrityError, second succeeds
|
||||
mock_session.commit.side_effect = [IntegrityError('', '', ''), None]
|
||||
# Patch generate methods to return the same codes on first attempt,
|
||||
# then different codes on second attempt
|
||||
call_count = {'user': 0, 'device': 0}
|
||||
original_generate_user_code = device_code_store.generate_user_code
|
||||
original_generate_device_code = device_code_store.generate_device_code
|
||||
|
||||
mock_device_code = MagicMock(spec=DeviceCode)
|
||||
mock_device_code.device_code = 'test-device-code-456'
|
||||
mock_device_code.user_code = 'TESTCD2'
|
||||
def mock_generate_user_code():
|
||||
call_count['user'] += 1
|
||||
if call_count['user'] == 1:
|
||||
return first_code.user_code # Collision
|
||||
return original_generate_user_code()
|
||||
|
||||
def mock_refresh(obj):
|
||||
obj.device_code = mock_device_code.device_code
|
||||
obj.user_code = mock_device_code.user_code
|
||||
def mock_generate_device_code():
|
||||
call_count['device'] += 1
|
||||
if call_count['device'] == 1:
|
||||
return first_code.device_code # Collision
|
||||
return original_generate_device_code()
|
||||
|
||||
mock_session.refresh.side_effect = mock_refresh
|
||||
device_code_store.generate_user_code = mock_generate_user_code
|
||||
device_code_store.generate_device_code = mock_generate_device_code
|
||||
|
||||
store = DeviceCodeStore(mock_session_maker)
|
||||
result = store.create_device_code(expires_in=600)
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
result = await device_code_store.create_device_code(expires_in=600)
|
||||
|
||||
assert isinstance(result, DeviceCode)
|
||||
assert mock_session.add.call_count == 2 # Two attempts
|
||||
assert mock_session.commit.call_count == 2 # Two attempts
|
||||
assert result.device_code != first_code.device_code # Should be different
|
||||
assert call_count['user'] == 2 # Two attempts
|
||||
|
||||
def test_create_device_code_max_attempts_exceeded(
|
||||
self, device_code_store, mock_session_maker
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_device_code_max_attempts_exceeded(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test device code creation failure after max attempts."""
|
||||
mock_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session_maker.return_value.__exit__.return_value = None
|
||||
# First create a device code
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
first_code = await device_code_store.create_device_code(expires_in=600)
|
||||
|
||||
# All attempts fail with IntegrityError
|
||||
mock_session.commit.side_effect = IntegrityError('', '', '')
|
||||
# Always return the same codes to cause repeated collisions
|
||||
device_code_store.generate_user_code = lambda: first_code.user_code
|
||||
device_code_store.generate_device_code = lambda: first_code.device_code
|
||||
|
||||
store = DeviceCodeStore(mock_session_maker)
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match='Failed to generate unique device codes after 3 attempts',
|
||||
):
|
||||
await device_code_store.create_device_code(
|
||||
expires_in=600, max_attempts=3
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match='Failed to generate unique device codes after 3 attempts',
|
||||
):
|
||||
store.create_device_code(expires_in=600, max_attempts=3)
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_device_code(self, device_code_store, async_session_maker):
|
||||
"""Test getting device code by device code."""
|
||||
# Create a device code first
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
created = await device_code_store.create_device_code(expires_in=600)
|
||||
result = await device_code_store.get_by_device_code(created.device_code)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'lookup_method,lookup_field',
|
||||
[
|
||||
('get_by_device_code', 'device_code'),
|
||||
('get_by_user_code', 'user_code'),
|
||||
],
|
||||
)
|
||||
def test_lookup_methods(
|
||||
self, device_code_store, mock_session, lookup_method, lookup_field
|
||||
assert result is not None
|
||||
assert result.device_code == created.device_code
|
||||
assert result.user_code == created.user_code
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_device_code_not_found(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test device code lookup methods."""
|
||||
test_code = 'test-code-123'
|
||||
mock_device_code = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_device_code
|
||||
)
|
||||
"""Test getting non-existent device code."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
result = await device_code_store.get_by_device_code('non-existent-code')
|
||||
|
||||
result = getattr(device_code_store, lookup_method)(test_code)
|
||||
assert result is None
|
||||
|
||||
assert result == mock_device_code
|
||||
mock_session.query.assert_called_once_with(DeviceCode)
|
||||
mock_session.query.return_value.filter_by.assert_called_once_with(
|
||||
**{lookup_field: test_code}
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_user_code(self, device_code_store, async_session_maker):
|
||||
"""Test getting device code by user code."""
|
||||
# Create a device code first
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
created = await device_code_store.create_device_code(expires_in=600)
|
||||
result = await device_code_store.get_by_user_code(created.user_code)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'device_exists,is_pending,expected_result',
|
||||
[
|
||||
(True, True, True), # Success case
|
||||
(False, True, False), # Device not found
|
||||
(True, False, False), # Device not pending
|
||||
],
|
||||
)
|
||||
def test_authorize_device_code(
|
||||
self,
|
||||
device_code_store,
|
||||
mock_session,
|
||||
device_exists,
|
||||
is_pending,
|
||||
expected_result,
|
||||
assert result is not None
|
||||
assert result.device_code == created.device_code
|
||||
assert result.user_code == created.user_code
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_user_code_not_found(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test device code authorization."""
|
||||
user_code = 'ABC12345'
|
||||
"""Test getting non-existent user code."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
result = await device_code_store.get_by_user_code('NOTFOUND')
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_device_code_success(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test successful device code authorization."""
|
||||
user_id = 'test-user-123'
|
||||
|
||||
if device_exists:
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = is_pending
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_device
|
||||
else:
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
result = device_code_store.authorize_device_code(user_code, user_id)
|
||||
|
||||
assert result == expected_result
|
||||
if expected_result:
|
||||
mock_device.authorize.assert_called_once_with(user_id)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_deny_device_code(self, device_code_store, mock_session):
|
||||
"""Test device code denial."""
|
||||
user_code = 'ABC12345'
|
||||
mock_device = MagicMock()
|
||||
mock_device.is_pending.return_value = True
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_device
|
||||
)
|
||||
|
||||
result = device_code_store.deny_device_code(user_code)
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
created = await device_code_store.create_device_code(expires_in=600)
|
||||
result = await device_code_store.authorize_device_code(
|
||||
created.user_code, user_id
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_device.deny.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Verify the device code was authorized in the database
|
||||
async with async_session_maker() as session:
|
||||
result_db = await session.execute(
|
||||
select(DeviceCode).filter(DeviceCode.user_code == created.user_code)
|
||||
)
|
||||
device_code = result_db.scalars().first()
|
||||
assert device_code.status == 'authorized'
|
||||
assert device_code.keycloak_user_id == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_device_code_not_found(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test authorizing non-existent device code."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
result = await device_code_store.authorize_device_code(
|
||||
'NOTFOUND', 'user-123'
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_device_code_not_pending(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test authorizing already authorized device code."""
|
||||
user_id = 'test-user-123'
|
||||
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
created = await device_code_store.create_device_code(expires_in=600)
|
||||
# First authorization
|
||||
await device_code_store.authorize_device_code(created.user_code, user_id)
|
||||
# Second authorization should fail
|
||||
result = await device_code_store.authorize_device_code(
|
||||
created.user_code, 'another-user'
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_device_code_success(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test successful device code denial."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
created = await device_code_store.create_device_code(expires_in=600)
|
||||
result = await device_code_store.deny_device_code(created.user_code)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify the device code was denied in the database
|
||||
async with async_session_maker() as session:
|
||||
result_db = await session.execute(
|
||||
select(DeviceCode).filter(DeviceCode.user_code == created.user_code)
|
||||
)
|
||||
device_code = result_db.scalars().first()
|
||||
assert device_code.status == 'denied'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_device_code_not_found(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test denying non-existent device code."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
result = await device_code_store.deny_device_code('NOTFOUND')
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_device_code_not_pending(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test denying already denied device code."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
created = await device_code_store.create_device_code(expires_in=600)
|
||||
# First denial
|
||||
await device_code_store.deny_device_code(created.user_code)
|
||||
# Second denial should fail
|
||||
result = await device_code_store.deny_device_code(created.user_code)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_poll_time_success(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test updating poll time."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
created = await device_code_store.create_device_code(expires_in=600)
|
||||
original_interval = created.current_interval
|
||||
result = await device_code_store.update_poll_time(
|
||||
created.device_code, increase_interval=True
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify the poll time was updated
|
||||
async with async_session_maker() as session:
|
||||
result_db = await session.execute(
|
||||
select(DeviceCode).filter(DeviceCode.device_code == created.device_code)
|
||||
)
|
||||
device_code = result_db.scalars().first()
|
||||
assert device_code.current_interval > original_interval
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_poll_time_not_found(
|
||||
self, device_code_store, async_session_maker
|
||||
):
|
||||
"""Test updating poll time for non-existent device code."""
|
||||
with patch('storage.device_code_store.a_session_maker', async_session_maker):
|
||||
result = await device_code_store.update_poll_time(
|
||||
'non-existent-code', increase_interval=False
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@@ -9,16 +9,35 @@ from storage.base import Base
|
||||
from storage.gitlab_webhook import GitlabWebhook
|
||||
from storage.gitlab_webhook_store import GitlabWebhookStore
|
||||
|
||||
# Use module-scoped engine to share database across fixtures
|
||||
_test_engine = None
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for each test case."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
async def async_engine(event_loop):
|
||||
"""Create an async SQLite engine for testing.
|
||||
|
||||
This fixture creates an in-memory SQLite database and ensures
|
||||
all tables are created before tests run.
|
||||
"""
|
||||
global _test_engine
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
_test_engine = engine
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
@@ -29,7 +48,7 @@ async def async_engine():
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope='function')
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker for testing."""
|
||||
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
@@ -37,8 +56,21 @@ async def async_session_maker(async_engine):
|
||||
|
||||
@pytest.fixture
|
||||
async def webhook_store(async_session_maker):
|
||||
"""Create a GitlabWebhookStore instance for testing."""
|
||||
return GitlabWebhookStore(a_session_maker=async_session_maker)
|
||||
"""Create a GitlabWebhookStore instance for testing.
|
||||
|
||||
This fixture injects the test's async_session_maker to ensure
|
||||
the store uses the same in-memory database as the test fixtures.
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
|
||||
store = GitlabWebhookStore()
|
||||
|
||||
# Inject the test session maker - this needs to replace the module-level import
|
||||
import storage.gitlab_webhook_store as store_module
|
||||
|
||||
store_module.a_session_maker = async_session_maker
|
||||
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -102,7 +134,7 @@ class TestGetWebhookByResourceOnly:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_project_webhook_by_resource_only(
|
||||
self, webhook_store, async_session_maker, sample_webhooks
|
||||
self, webhook_store, sample_webhooks
|
||||
):
|
||||
"""Test getting a project webhook by resource ID without user_id filter."""
|
||||
# Arrange
|
||||
|
||||
232
enterprise/tests/unit/storage/test_jira_integration_store.py
Normal file
232
enterprise/tests/unit/storage/test_jira_integration_store.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Tests for JiraIntegrationStore async methods.
|
||||
|
||||
The store uses async database sessions (a_session_maker) for all operations,
|
||||
which is critical for avoiding asyncpg event loop issues when called from
|
||||
FastAPI async endpoints.
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from storage.jira_integration_store import JiraIntegrationStore
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store():
|
||||
"""Create a JiraIntegrationStore instance."""
|
||||
return JiraIntegrationStore()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_async_session():
|
||||
"""Factory to create properly mocked async session context manager."""
|
||||
|
||||
def _create(query_result=None, all_results=None):
|
||||
mock_session = Mock()
|
||||
mock_result = Mock()
|
||||
|
||||
if all_results is not None:
|
||||
mock_result.scalars.return_value.all.return_value = all_results
|
||||
else:
|
||||
mock_result.scalars.return_value.first.return_value = query_result
|
||||
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.refresh = AsyncMock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_context_manager():
|
||||
yield mock_session
|
||||
|
||||
return mock_context_manager, mock_session
|
||||
|
||||
return _create
|
||||
|
||||
|
||||
class TestJiraIntegrationStoreAsyncMethods:
|
||||
"""Tests verifying JiraIntegrationStore methods use async sessions correctly."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_workspace_by_id_returns_workspace(
|
||||
self, store, create_mock_async_session
|
||||
):
|
||||
"""Test get_workspace_by_id returns workspace when found."""
|
||||
# Arrange
|
||||
mock_workspace = Mock(spec=JiraWorkspace)
|
||||
mock_workspace.id = 1
|
||||
mock_workspace.name = 'test-workspace'
|
||||
|
||||
mock_context_manager, mock_session = create_mock_async_session(mock_workspace)
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'storage.jira_integration_store.a_session_maker', mock_context_manager
|
||||
):
|
||||
result = await store.get_workspace_by_id(1)
|
||||
|
||||
# Assert
|
||||
assert result == mock_workspace
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_workspace_by_id_returns_none_when_not_found(
|
||||
self, store, create_mock_async_session
|
||||
):
|
||||
"""Test get_workspace_by_id returns None when workspace not found."""
|
||||
# Arrange
|
||||
mock_context_manager, mock_session = create_mock_async_session(None)
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'storage.jira_integration_store.a_session_maker', mock_context_manager
|
||||
):
|
||||
result = await store.get_workspace_by_id(999)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_workspace_by_name_normalizes_to_lowercase(
|
||||
self, store, create_mock_async_session
|
||||
):
|
||||
"""Test get_workspace_by_name converts name to lowercase for query."""
|
||||
# Arrange
|
||||
mock_workspace = Mock(spec=JiraWorkspace)
|
||||
mock_workspace.name = 'test-workspace'
|
||||
|
||||
mock_context_manager, mock_session = create_mock_async_session(mock_workspace)
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'storage.jira_integration_store.a_session_maker', mock_context_manager
|
||||
):
|
||||
result = await store.get_workspace_by_name('TEST-WORKSPACE')
|
||||
|
||||
# Assert
|
||||
assert result == mock_workspace
|
||||
# Verify the query was executed (filter includes lowercase conversion)
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_user_filters_by_status(
|
||||
self, store, create_mock_async_session
|
||||
):
|
||||
"""Test get_active_user only returns users with active status."""
|
||||
# Arrange
|
||||
mock_user = Mock(spec=JiraUser)
|
||||
mock_user.jira_user_id = 'jira-123'
|
||||
mock_user.jira_workspace_id = 1
|
||||
mock_user.status = 'active'
|
||||
|
||||
mock_context_manager, mock_session = create_mock_async_session(mock_user)
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'storage.jira_integration_store.a_session_maker', mock_context_manager
|
||||
):
|
||||
result = await store.get_active_user('jira-123', 1)
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_workspace_adds_and_commits(
|
||||
self, store, create_mock_async_session
|
||||
):
|
||||
"""Test create_workspace properly adds, commits, and refreshes."""
|
||||
# Arrange
|
||||
mock_context_manager, mock_session = create_mock_async_session(None)
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'storage.jira_integration_store.a_session_maker', mock_context_manager
|
||||
):
|
||||
await store.create_workspace(
|
||||
name='TEST-WORKSPACE',
|
||||
jira_cloud_id='cloud-123',
|
||||
admin_user_id='admin-user',
|
||||
encrypted_webhook_secret='encrypted-secret',
|
||||
svc_acc_email='svc@test.com',
|
||||
encrypted_svc_acc_api_key='encrypted-key',
|
||||
status='active',
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
|
||||
# Verify workspace was created with lowercase name
|
||||
added_workspace = mock_session.add.call_args[0][0]
|
||||
assert added_workspace.name == 'test-workspace'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_integration_status_raises_if_not_found(
|
||||
self, store, create_mock_async_session
|
||||
):
|
||||
"""Test update_user_integration_status raises ValueError if user not found."""
|
||||
# Arrange
|
||||
mock_context_manager, mock_session = create_mock_async_session(None)
|
||||
|
||||
# Act & Assert
|
||||
with patch(
|
||||
'storage.jira_integration_store.a_session_maker', mock_context_manager
|
||||
):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await store.update_user_integration_status('unknown-user', 'inactive')
|
||||
|
||||
assert 'Jira user not found' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_workspace_deactivates_all_users(
|
||||
self, store, create_mock_async_session
|
||||
):
|
||||
"""Test deactivate_workspace sets all users and workspace to inactive."""
|
||||
# Arrange
|
||||
mock_user1 = Mock(spec=JiraUser)
|
||||
mock_user1.status = 'active'
|
||||
mock_user2 = Mock(spec=JiraUser)
|
||||
mock_user2.status = 'active'
|
||||
|
||||
mock_workspace = Mock(spec=JiraWorkspace)
|
||||
mock_workspace.status = 'active'
|
||||
|
||||
mock_session = Mock()
|
||||
|
||||
# First execute returns users, second returns workspace
|
||||
call_count = [0]
|
||||
|
||||
def execute_side_effect(*args, **kwargs):
|
||||
result = Mock()
|
||||
if call_count[0] == 0:
|
||||
result.scalars.return_value.all.return_value = [mock_user1, mock_user2]
|
||||
else:
|
||||
result.scalars.return_value.first.return_value = mock_workspace
|
||||
call_count[0] += 1
|
||||
return result
|
||||
|
||||
mock_session.execute = AsyncMock(side_effect=execute_side_effect)
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_context_manager():
|
||||
yield mock_session
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
'storage.jira_integration_store.a_session_maker', mock_context_manager
|
||||
):
|
||||
await store.deactivate_workspace(1)
|
||||
|
||||
# Assert
|
||||
assert mock_user1.status == 'inactive'
|
||||
assert mock_user2.status == 'inactive'
|
||||
assert mock_workspace.status == 'inactive'
|
||||
mock_session.commit.assert_called_once()
|
||||
183
enterprise/tests/unit/storage/test_org_app_settings_store.py
Normal file
183
enterprise/tests/unit/storage/test_org_app_settings_store.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Unit tests for OrgAppSettingsStore.
|
||||
|
||||
Tests the async database operations for organization app settings.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from server.routes.org_models import OrgAppSettingsUpdate
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_app_settings_store import OrgAppSettingsStore
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker for testing."""
|
||||
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_org_by_user_id_success(async_session_maker):
|
||||
"""
|
||||
GIVEN: A user exists with a current organization
|
||||
WHEN: get_current_org_by_user_id is called with the user's ID
|
||||
THEN: The organization is returned with correct data
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(
|
||||
name='test-org',
|
||||
enable_proactive_conversation_starters=True,
|
||||
enable_solvability_analysis=False,
|
||||
max_budget_per_task=25.0,
|
||||
)
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
current_org_id=org.id,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
user_id = str(user.id)
|
||||
|
||||
# Act
|
||||
store = OrgAppSettingsStore(db_session=session)
|
||||
result = await store.get_current_org_by_user_id(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.name == 'test-org'
|
||||
assert result.enable_proactive_conversation_starters is True
|
||||
assert result.enable_solvability_analysis is False
|
||||
assert result.max_budget_per_task == 25.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_org_by_user_id_user_not_found(async_session_maker):
|
||||
"""
|
||||
GIVEN: A user does not exist in the database
|
||||
WHEN: get_current_org_by_user_id is called with a non-existent ID
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
|
||||
# Act
|
||||
async with async_session_maker() as session:
|
||||
store = OrgAppSettingsStore(db_session=session)
|
||||
result = await store.get_current_org_by_user_id(non_existent_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_success(async_session_maker):
|
||||
"""
|
||||
GIVEN: An organization exists in the database
|
||||
WHEN: update_org_app_settings is called with new values
|
||||
THEN: The organization's settings are updated and returned
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(
|
||||
name='test-org',
|
||||
enable_proactive_conversation_starters=True,
|
||||
enable_solvability_analysis=False,
|
||||
max_budget_per_task=10.0,
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
update_data = OrgAppSettingsUpdate(
|
||||
enable_proactive_conversation_starters=False,
|
||||
enable_solvability_analysis=True,
|
||||
max_budget_per_task=50.0,
|
||||
)
|
||||
|
||||
# Act
|
||||
store = OrgAppSettingsStore(db_session=session)
|
||||
result = await store.update_org_app_settings(org_id, update_data)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.enable_proactive_conversation_starters is False
|
||||
assert result.enable_solvability_analysis is True
|
||||
assert result.max_budget_per_task == 50.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_partial(async_session_maker):
|
||||
"""
|
||||
GIVEN: An organization exists with existing settings
|
||||
WHEN: update_org_app_settings is called with only some fields
|
||||
THEN: Only the provided fields are updated, others remain unchanged
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(
|
||||
name='test-org',
|
||||
enable_proactive_conversation_starters=True,
|
||||
enable_solvability_analysis=False,
|
||||
max_budget_per_task=10.0,
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Only update max_budget_per_task
|
||||
update_data = OrgAppSettingsUpdate(max_budget_per_task=100.0)
|
||||
|
||||
# Act
|
||||
store = OrgAppSettingsStore(db_session=session)
|
||||
result = await store.update_org_app_settings(org_id, update_data)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.max_budget_per_task == 100.0
|
||||
assert result.enable_proactive_conversation_starters is True # Unchanged
|
||||
assert result.enable_solvability_analysis is False # Unchanged
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_org_not_found(async_session_maker):
|
||||
"""
|
||||
GIVEN: An organization does not exist in the database
|
||||
WHEN: update_org_app_settings is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
non_existent_id = uuid.uuid4()
|
||||
update_data = OrgAppSettingsUpdate(enable_proactive_conversation_starters=False)
|
||||
|
||||
# Act
|
||||
async with async_session_maker() as session:
|
||||
store = OrgAppSettingsStore(db_session=session)
|
||||
result = await store.update_org_app_settings(non_existent_id, update_data)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
175
enterprise/tests/unit/storage/test_org_llm_settings_store.py
Normal file
175
enterprise/tests/unit/storage/test_org_llm_settings_store.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Unit tests for OrgLLMSettingsStore.
|
||||
|
||||
Tests the async database operations for organization LLM settings.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from server.routes.org_models import OrgLLMSettingsUpdate
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_llm_settings_store import OrgLLMSettingsStore
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker for testing."""
|
||||
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_org_by_user_id_success(async_session_maker):
|
||||
"""
|
||||
GIVEN: A user exists with a current_org_id
|
||||
WHEN: get_current_org_by_user_id is called
|
||||
THEN: The user's current organization is returned
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org', default_llm_model='claude-3')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
user_id = str(user.id)
|
||||
|
||||
# Act
|
||||
store = OrgLLMSettingsStore(db_session=session)
|
||||
result = await store.get_current_org_by_user_id(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.name == 'test-org'
|
||||
assert result.default_llm_model == 'claude-3'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_org_by_user_id_user_not_found(async_session_maker):
|
||||
"""
|
||||
GIVEN: A user does not exist in the database
|
||||
WHEN: get_current_org_by_user_id is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
|
||||
# Act
|
||||
async with async_session_maker() as session:
|
||||
store = OrgLLMSettingsStore(db_session=session)
|
||||
result = await store.get_current_org_by_user_id(non_existent_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_llm_settings_success(async_session_maker):
|
||||
"""
|
||||
GIVEN: An organization exists in the database
|
||||
WHEN: update_org_llm_settings is called with new values
|
||||
THEN: The organization's LLM settings are updated and returned
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org', default_llm_model='old-model')
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
update_data = OrgLLMSettingsUpdate(
|
||||
default_llm_model='new-model',
|
||||
agent='CodeActAgent',
|
||||
confirmation_mode=True,
|
||||
)
|
||||
|
||||
# Act
|
||||
store = OrgLLMSettingsStore(db_session=session)
|
||||
with patch(
|
||||
'storage.org_llm_settings_store.OrgMemberStore.update_all_members_llm_settings_async',
|
||||
AsyncMock(),
|
||||
):
|
||||
result = await store.update_org_llm_settings(org_id, update_data)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.default_llm_model == 'new-model'
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.confirmation_mode is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_llm_settings_org_not_found(async_session_maker):
|
||||
"""
|
||||
GIVEN: An organization does not exist in the database
|
||||
WHEN: update_org_llm_settings is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
non_existent_org_id = uuid.uuid4()
|
||||
update_data = OrgLLMSettingsUpdate(default_llm_model='new-model')
|
||||
|
||||
# Act
|
||||
async with async_session_maker() as session:
|
||||
store = OrgLLMSettingsStore(db_session=session)
|
||||
result = await store.update_org_llm_settings(non_existent_org_id, update_data)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_llm_settings_propagates_to_members(async_session_maker):
|
||||
"""
|
||||
GIVEN: An organization exists with update data containing member-relevant settings
|
||||
WHEN: update_org_llm_settings is called
|
||||
THEN: Member settings are propagated via OrgMemberStore
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org', default_llm_model='old-model')
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
update_data = OrgLLMSettingsUpdate(
|
||||
default_llm_model='new-model',
|
||||
llm_api_key='new-api-key',
|
||||
)
|
||||
|
||||
# Act
|
||||
store = OrgLLMSettingsStore(db_session=session)
|
||||
with patch(
|
||||
'storage.org_llm_settings_store.OrgMemberStore.update_all_members_llm_settings_async',
|
||||
AsyncMock(),
|
||||
) as mock_update_members:
|
||||
await store.update_org_llm_settings(org_id, update_data)
|
||||
|
||||
# Assert
|
||||
mock_update_members.assert_called_once()
|
||||
call_args = mock_update_members.call_args
|
||||
member_settings = call_args[0][2]
|
||||
assert member_settings.llm_model == 'new-model'
|
||||
assert member_settings.llm_api_key == 'new-api-key'
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user