Compare commits

...

9 Commits
main ... APP-19

Author SHA1 Message Date
amanape
6776b8b34e Fix middleware to properly trigger authentication using get_user_auth 2025-10-01 00:14:44 +04:00
amanape
14f50a4816 Fix pre-commit formatting issues 2025-09-30 23:51:16 +04:00
amanape
fa4278b3b0 Fix mypy error and formatting issues 2025-09-30 23:35:11 +04:00
amanape
4b2dfc1f0d Fix linting issues in middleware docstrings 2025-09-30 23:29:06 +04:00
amanape
24d5a33b0e Merge branch 'main' into APP-19 2025-09-30 23:17:22 +04:00
amanape
ba1edf0f0a Add debug logging to LLM middleware and fix middleware order 2025-09-30 23:14:13 +04:00
amanape
741ea28cd3 Improve middleware JSON parsing and add detailed logging 2025-09-30 22:17:23 +04:00
amanape
191e3b9cef Add enterprise LLM settings middleware with pro user validation 2025-09-30 22:09:20 +04:00
amanape
eac39e85eb revert 2025-09-30 21:29:35 +04:00
110 changed files with 837 additions and 615 deletions

View File

@@ -17,8 +17,7 @@ class SaaSExperimentManager(ExperimentManager):
def run_conversation_variant_test(
user_id, conversation_id, conversation_settings
) -> ConversationInitData:
"""
Run conversation variant test and potentially modify the conversation settings
"""Run conversation variant test and potentially modify the conversation settings
based on the PostHog feature flags.
Args:
@@ -53,8 +52,7 @@ class SaaSExperimentManager(ExperimentManager):
def run_config_variant_test(
user_id: str | None, conversation_id: str, config: OpenHandsConfig
) -> OpenHandsConfig:
"""
Run agent config variant test and potentially modify the OpenHands config
"""Run agent config variant test and potentially modify the OpenHands config
based on the current experiment type and PostHog feature flags.
Args:

View File

@@ -1,5 +1,4 @@
"""
LiteLLM model experiment handler.
"""LiteLLM model experiment handler.
This module contains the handler for the LiteLLM model experiment.
"""
@@ -18,8 +17,7 @@ from openhands.core.logger import openhands_logger as logger
def handle_litellm_default_model_experiment(
user_id, conversation_id, conversation_settings
):
"""
Handle the LiteLLM model experiment.
"""Handle the LiteLLM model experiment.
Args:
user_id: The user ID

View File

@@ -1,5 +1,4 @@
"""
System prompt experiment handler.
"""System prompt experiment handler.
This module contains the handler for the system prompt experiment that uses
the PostHog variant as the system prompt filename.
@@ -17,8 +16,7 @@ from openhands.core.logger import openhands_logger as logger
def _get_system_prompt_variant(user_id, conversation_id):
"""
Get the system prompt variant for the experiment.
"""Get the system prompt variant for the experiment.
Args:
user_id: The user ID
@@ -119,8 +117,7 @@ def _get_system_prompt_variant(user_id, conversation_id):
def handle_system_prompt_experiment(
user_id, conversation_id, config: OpenHandsConfig
) -> OpenHandsConfig:
"""
Handle the system prompt experiment for OpenHands config.
"""Handle the system prompt experiment for OpenHands config.
Args:
user_id: The user ID

View File

@@ -1,5 +1,4 @@
"""
LiteLLM model experiment handler.
"""LiteLLM model experiment handler.
This module contains the handler for the LiteLLM model experiment.
"""
@@ -110,8 +109,7 @@ def handle_claude4_vs_gpt5_experiment(
conversation_id: str,
conversation_settings: ConversationInitData,
) -> ConversationInitData:
"""
Handle the LiteLLM model experiment.
"""Handle the LiteLLM model experiment.
Args:
user_id: The user ID
@@ -121,7 +119,6 @@ def handle_claude4_vs_gpt5_experiment(
Returns:
Modified conversation settings
"""
enabled_variant = _get_model_variant(user_id, conversation_id)
if not enabled_variant:

View File

@@ -1,5 +1,4 @@
"""
Condenser max step experiment handler.
"""Condenser max step experiment handler.
This module contains the handler for the condenser max step experiment that tests
different max_size values for the condenser configuration.
@@ -15,8 +14,7 @@ from openhands.server.session.conversation_init_data import ConversationInitData
def _get_condenser_max_step_variant(user_id, conversation_id):
"""
Get the condenser max step variant for the experiment.
"""Get the condenser max step variant for the experiment.
Args:
user_id: The user ID
@@ -119,8 +117,7 @@ def handle_condenser_max_step_experiment(
conversation_id: str,
conversation_settings: ConversationInitData,
) -> ConversationInitData:
"""
Handle the condenser max step experiment for conversation settings.
"""Handle the condenser max step experiment for conversation settings.
We should not modify persistent user settings. Instead, apply the experiment
variant to the conversation's in-memory settings object for this session only.
@@ -131,7 +128,6 @@ def handle_condenser_max_step_experiment(
Returns the (potentially) modified conversation_settings.
"""
enabled_variant = _get_condenser_max_step_variant(user_id, conversation_id)
if enabled_variant is None:

View File

@@ -1,5 +1,4 @@
"""
Experiment versions package.
"""Experiment versions package.
This package contains handlers for different experiment versions.
"""

View File

@@ -43,8 +43,7 @@ class TriggerType(str, Enum):
class GitHubDataCollector:
"""
Saves data on Cloud Resolver Interactions
"""Saves data on Cloud Resolver Interactions
1. We always save
- Resolver trigger (comment or label)
@@ -89,8 +88,7 @@ class GitHubDataCollector:
self.conversation_id = None
async def _get_repo_node_id(self, repo_id: str, gh_client) -> str:
"""
Get the new GitHub GraphQL node ID for a repository using the GitHub client.
"""Get the new GitHub GraphQL node ID for a repository using the GitHub client.
Args:
repo_id: Numeric repository ID as string (e.g., "123456789")
@@ -136,10 +134,7 @@ class GitHubDataCollector:
def _get_issue_comments(
self, installation_id: str, repo_name: str, issue_number: int, conversation_id
) -> list[dict[str, Any]]:
"""
Retrieve all comments from an issue until a comment with conversation_id is found
"""
"""Retrieve all comments from an issue until a comment with conversation_id is found"""
try:
installation_token = self._get_installation_access_token(installation_id)
@@ -175,18 +170,16 @@ class GitHubDataCollector:
github_view: GithubIssue,
trigger_type: TriggerType,
) -> None:
"""
Save issue data when it's labeled with openhands
"""Save issue data when it's labeled with openhands
1. Save under {conversation_dir}/{conversation_id}/github_data/issue_{issue_number}.json
2. Save issue snapshot (title, body, comments)
3. Save trigger type (label)
4. Save PR opened (if exists, this information comes later when agent has finished its task)
- Save commit shas
- Save author info
5. Was PR merged or closed
1. Save under {conversation_dir}/{conversation_id}/github_data/issue_{issue_number}.json
2. Save issue snapshot (title, body, comments)
3. Save trigger type (label)
4. Save PR opened (if exists, this information comes later when agent has finished its task)
- Save commit shas
- Save author info
5. Was PR merged or closed
"""
conversation_id = github_view.conversation_id
if not conversation_id:
@@ -385,7 +378,6 @@ class GitHubDataCollector:
openhands_general_comment_count: int = 0,
) -> dict:
"""Build the final data structure for JSON storage"""
is_merged = pr_data['merged']
merged_by = None
merge_commit_sha = None
@@ -419,8 +411,7 @@ class GitHubDataCollector:
}
async def save_full_pr(self, openhands_pr: OpenhandsPR) -> None:
"""
Save PR information including metadata and commit details using GraphQL
"""Save PR information including metadata and commit details using GraphQL
Saves:
- Repo metadata (repo name, languages, contributors)
@@ -606,17 +597,12 @@ class GitHubDataCollector:
return None
def _is_pr_closed_or_merged(self, payload):
"""
Check if PR was closed (regardless of conversation URL)
"""
"""Check if PR was closed (regardless of conversation URL)"""
action = payload.get('action', '')
return action == 'closed' and 'pull_request' in payload
def _track_closed_or_merged_pr(self, payload):
"""
Track PR closed/merged event
"""
"""Track PR closed/merged event"""
repo_id = str(payload['repository']['id'])
pr_number = payload['number']
installation_id = str(payload['installation']['id'])

View File

@@ -103,8 +103,7 @@ class SaaSGitHubService(GitHubService):
}
async def get_repository_node_id(self, repo_id: str) -> str:
"""
Get the new GitHub GraphQL node ID for a repository using REST API.
"""Get the new GitHub GraphQL node ID for a repository using REST API.
Args:
repo_id: Numeric repository ID as string (e.g., "123456789")

View File

@@ -39,7 +39,6 @@ def fetch_github_issue_context(
Returns:
A comprehensive string containing the issue/PR context
"""
# Build context string
context_parts = []

View File

@@ -55,7 +55,6 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
This function checks both the global environment variable kill switch AND
the user's individual setting. Both must be true for the function to return true.
"""
# If no user ID is provided, we can't check user settings
if not user_id:
return False

View File

@@ -43,8 +43,7 @@ class GitlabManager(Manager):
async def _user_has_write_access_to_repo(
self, project_id: str, user_id: str
) -> bool:
"""
Check if the user has write access to the repository (can pull/push changes and open merge requests).
"""Check if the user has write access to the repository (can pull/push changes and open merge requests).
Args:
project_id: The ID of the GitLab project
@@ -54,7 +53,6 @@ class GitlabManager(Manager):
Returns:
bool: True if the user has write access to the repository, False otherwise
"""
keycloak_user_id = await self.token_manager.get_user_id_from_idp_user_id(
user_id, ProviderType.GITLAB
)
@@ -117,8 +115,7 @@ class GitlabManager(Manager):
return has_write_access
async def send_message(self, message: Message, gitlab_view: ResolverViewInterface):
"""
Send a message to GitLab based on the view type.
"""Send a message to GitLab based on the view type.
Args:
message: The message to send
@@ -165,8 +162,7 @@ class GitlabManager(Manager):
)
async def start_job(self, gitlab_view: GitlabViewType):
"""
Start a job for the GitLab view.
"""Start a job for the GitLab view.
Args:
gitlab_view: The GitLab view object containing issue/PR/comment info

View File

@@ -81,8 +81,7 @@ class SaaSGitLabService(GitLabService):
return gitlab_token
async def get_owned_groups(self) -> list[dict]:
"""
Get all groups for which the current user is the owner.
"""Get all groups for which the current user is the owner.
Returns:
list[dict]: A list of groups owned by the current user.
@@ -98,8 +97,7 @@ class SaaSGitLabService(GitLabService):
return []
async def add_owned_projects_and_groups_to_db(self, owned_personal_projects):
"""
Add owned projects and groups to the database for webhook tracking.
"""Add owned projects and groups to the database for webhook tracking.
Args:
owned_personal_projects: List of personal projects owned by the user
@@ -147,8 +145,7 @@ class SaaSGitLabService(GitLabService):
async def store_repository_data(
self, users_personal_projects: list[dict], repositories: list[Repository]
) -> None:
"""
Store repository data in the database.
"""Store repository data in the database.
This function combines the functionality of add_owned_projects_and_groups_to_db and store_repositories_in_db.
Args:
@@ -171,8 +168,7 @@ class SaaSGitLabService(GitLabService):
async def get_all_repositories(
self, sort: str, app_mode: AppMode, store_in_background: bool = True
) -> list[Repository]:
"""
Get repositories for the authenticated user, including information about the kind of project.
"""Get repositories for the authenticated user, including information about the kind of project.
Also collects repositories where the kind is "user" and the user is the owner.
Args:
@@ -270,8 +266,7 @@ class SaaSGitLabService(GitLabService):
async def check_resource_exists(
self, resource_type: GitLabResourceType, resource_id: str
) -> tuple[bool, WebhookStatus | None]:
"""
Check if resource exists and the user has access to it.
"""Check if resource exists and the user has access to it.
Args:
resource_type: The type of resource
@@ -282,7 +277,6 @@ class SaaSGitLabService(GitLabService):
- bool: True if the resource exists and the user has access to it, False otherwise
- str: A reason message explaining the result
"""
if resource_type == GitLabResourceType.GROUP:
url = f'{self.BASE_URL}/groups/{resource_id}'
else:
@@ -301,8 +295,7 @@ class SaaSGitLabService(GitLabService):
async def check_webhook_exists_on_resource(
self, resource_type: GitLabResourceType, resource_id: str, webhook_url: str
) -> tuple[bool, WebhookStatus | None]:
"""
Check if a webhook already exists for resource with a specific URL.
"""Check if a webhook already exists for resource with a specific URL.
Args:
resource_type: The type of resource
@@ -314,7 +307,6 @@ class SaaSGitLabService(GitLabService):
- bool: True if the webhook exists, False otherwise
- str: A reason message explaining the result
"""
# Construct the URL based on the resource type
if resource_type == GitLabResourceType.GROUP:
url = f'{self.BASE_URL}/groups/{resource_id}/hooks'
@@ -343,8 +335,7 @@ class SaaSGitLabService(GitLabService):
async def check_user_has_admin_access_to_resource(
self, resource_type: GitLabResourceType, resource_id: str
) -> tuple[bool, WebhookStatus | None]:
"""
Check if the user has admin access to resource (is either an owner or maintainer)
"""Check if the user has admin access to resource (is either an owner or maintainer)
Args:
resource_type: The type of resource
@@ -355,7 +346,6 @@ class SaaSGitLabService(GitLabService):
- bool: True if the user has admin access to the resource (owner or maintainer), False otherwise
- str: A reason message explaining the result
"""
# For groups, we need to check if the user is an owner or maintainer
if resource_type == GitLabResourceType.GROUP:
url = f'{self.BASE_URL}/groups/{resource_id}/members/all'
@@ -412,8 +402,7 @@ class SaaSGitLabService(GitLabService):
webhook_uuid: str,
scopes: list[str],
) -> tuple[str | None, WebhookStatus | None]:
"""
Install webhook for user's group or project
"""Install webhook for user's group or project
Args:
resource_type: The type of resource
@@ -428,7 +417,6 @@ class SaaSGitLabService(GitLabService):
- bool: True if installation was successful, False otherwise
- str: A reason message explaining the result
"""
description = 'Cloud OpenHands Resolver'
# Set up webhook parameters
@@ -500,9 +488,7 @@ class SaaSGitLabService(GitLabService):
async def reply_to_issue(
self, project_id: str, issue_number: str, discussion_id: str | None, body: str
):
"""
Either create new comment thread, or reply to comment thread (depending on discussion_id param)
"""
"""Either create new comment thread, or reply to comment thread (depending on discussion_id param)"""
try:
if discussion_id:
url = f'{self.BASE_URL}/projects/{project_id}/issues/{issue_number}/discussions/{discussion_id}/notes'
@@ -517,9 +503,7 @@ class SaaSGitLabService(GitLabService):
async def reply_to_mr(
self, project_id: str, merge_request_iid: str, discussion_id: str, body: str
):
"""
Reply to comment thread on MR
"""
"""Reply to comment thread on MR"""
try:
url = f'{self.BASE_URL}/projects/{project_id}/merge_requests/{merge_request_iid}/discussions/{discussion_id}/notes'
params = {'body': body}

View File

@@ -48,7 +48,6 @@ class JiraManager(Manager):
self, jira_user_id: str, workspace_id: int
) -> tuple[JiraUser | None, UserAuth | None]:
"""Authenticate Jira user and get their OpenHands user auth."""
# Find active Jira user by Keycloak user ID and workspace ID
jira_user = await self.integration_store.get_active_user(
jira_user_id, workspace_id
@@ -206,7 +205,6 @@ class JiraManager(Manager):
async def receive_message(self, message: Message):
"""Process incoming Jira webhook message."""
payload = message.message.get('payload', {})
job_context = self.parse_webhook(payload)
@@ -299,10 +297,7 @@ class JiraManager(Manager):
async def is_job_requested(
self, message: Message, jira_view: JiraViewInterface
) -> bool:
"""
Check if a job is requested and handle repository selection.
"""
"""Check if a job is requested and handle repository selection."""
if isinstance(jira_view, JiraExistingConversationView):
return True

View File

@@ -35,7 +35,6 @@ class JiraNewConversationView(JiraViewInterface):
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
instructions_template = jinja_env.get_template('jira_instructions.j2')
instructions = instructions_template.render()
@@ -52,7 +51,6 @@ class JiraNewConversationView(JiraViewInterface):
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
"""Create a new Jira conversation"""
if not self.selected_repo:
raise StartingConvoException('No repository selected for this conversation')
@@ -112,7 +110,6 @@ class JiraExistingConversationView(JiraViewInterface):
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
user_msg_template = jinja_env.get_template('jira_existing_conversation.j2')
user_msg = user_msg_template.render(
issue_key=self.job_context.issue_key,
@@ -125,7 +122,6 @@ class JiraExistingConversationView(JiraViewInterface):
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
"""Update an existing Jira conversation"""
user_id = self.jira_user.keycloak_user_id
try:
@@ -191,7 +187,6 @@ class JiraFactory:
jira_workspace: JiraWorkspace,
) -> JiraViewInterface:
"""Create appropriate Jira view based on the message and user state"""
if not jira_user or not saas_user_auth or not jira_workspace:
raise StartingConvoException('User not authenticated with Jira integration')

View File

@@ -48,7 +48,6 @@ class JiraDcManager(Manager):
self, user_email: str, jira_dc_user_id: str, workspace_id: int
) -> tuple[JiraDcUser | None, UserAuth | None]:
"""Authenticate Jira DC user and get their OpenHands user auth."""
if not jira_dc_user_id or jira_dc_user_id == 'none':
# Get Keycloak user ID from email
keycloak_user_id = await self.token_manager.get_user_id_from_user_email(
@@ -221,7 +220,6 @@ class JiraDcManager(Manager):
async def receive_message(self, message: Message):
"""Process incoming Jira DC webhook message."""
payload = message.message.get('payload', {})
job_context = self.parse_webhook(payload)
@@ -315,10 +313,7 @@ class JiraDcManager(Manager):
async def is_job_requested(
self, message: Message, jira_dc_view: JiraDcViewInterface
) -> bool:
"""
Check if a job is requested and handle repository selection.
"""
"""Check if a job is requested and handle repository selection."""
if isinstance(jira_dc_view, JiraDcExistingConversationView):
return True

View File

@@ -38,7 +38,6 @@ class JiraDcNewConversationView(JiraDcViewInterface):
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
instructions_template = jinja_env.get_template('jira_dc_instructions.j2')
instructions = instructions_template.render()
@@ -55,7 +54,6 @@ class JiraDcNewConversationView(JiraDcViewInterface):
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
"""Create a new Jira DC conversation"""
if not self.selected_repo:
raise StartingConvoException('No repository selected for this conversation')
@@ -115,7 +113,6 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
user_msg_template = jinja_env.get_template('jira_dc_existing_conversation.j2')
user_msg = user_msg_template.render(
issue_key=self.job_context.issue_key,
@@ -128,7 +125,6 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
"""Update an existing Jira conversation"""
user_id = self.jira_dc_user.keycloak_user_id
try:
@@ -195,7 +191,6 @@ class JiraDcFactory:
jira_dc_workspace: JiraDcWorkspace,
) -> JiraDcViewInterface:
"""Create appropriate Jira DC view based on the payload."""
if not jira_dc_user or not saas_user_auth or not jira_dc_workspace:
raise StartingConvoException('User not authenticated with Jira integration')

View File

@@ -46,7 +46,6 @@ class LinearManager(Manager):
self, linear_user_id: str, workspace_id: int
) -> tuple[LinearUser | None, UserAuth | None]:
"""Authenticate Linear user and get their OpenHands user auth."""
# Find active Linear user by Linear user ID and workspace ID
linear_user = await self.integration_store.get_active_user(
linear_user_id, workspace_id
@@ -305,10 +304,7 @@ class LinearManager(Manager):
async def is_job_requested(
self, message: Message, linear_view: LinearViewInterface
) -> bool:
"""
Check if a job is requested and handle repository selection.
"""
"""Check if a job is requested and handle repository selection."""
if isinstance(linear_view, LinearExistingConversationView):
return True

View File

@@ -35,7 +35,6 @@ class LinearNewConversationView(LinearViewInterface):
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
instructions_template = jinja_env.get_template('linear_instructions.j2')
instructions = instructions_template.render()
@@ -52,7 +51,6 @@ class LinearNewConversationView(LinearViewInterface):
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
"""Create a new Linear conversation"""
if not self.selected_repo:
raise StartingConvoException('No repository selected for this conversation')
@@ -112,7 +110,6 @@ class LinearExistingConversationView(LinearViewInterface):
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"""Instructions passed when conversation is first initialized"""
user_msg_template = jinja_env.get_template('linear_existing_conversation.j2')
user_msg = user_msg_template.render(
issue_key=self.job_context.issue_key,
@@ -125,7 +122,6 @@ class LinearExistingConversationView(LinearViewInterface):
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
"""Update an existing Linear conversation"""
user_id = self.linear_user.keycloak_user_id
try:
@@ -192,7 +188,6 @@ class LinearFactory:
linear_workspace: LinearWorkspace,
) -> LinearViewInterface:
"""Create appropriate Linear view based on the message and user state"""
if not linear_user or not saas_user_auth or not linear_workspace:
raise StartingConvoException(
'User not authenticated with Linear integration'

View File

@@ -8,22 +8,22 @@ class Manager(ABC):
@abstractmethod
async def receive_message(self, message: Message):
"Receive message from integration"
"""Receive message from integration"""
raise NotImplementedError
@abstractmethod
def send_message(self, message: Message):
"Send message to integration from Openhands server"
"""Send message to integration from Openhands server"""
raise NotImplementedError
@abstractmethod
async def is_job_requested(self, message: Message) -> bool:
"Confirm that a job is being requested"
"""Confirm that a job is being requested"""
raise NotImplementedError
@abstractmethod
def start_job(self):
"Kick off a job with openhands agent"
"""Kick off a job with openhands agent"""
raise NotImplementedError
def create_outgoing_message(self, msg: str | dict, ephemeral: bool = False):

View File

@@ -244,13 +244,11 @@ class SlackManager(Manager):
async def is_job_requested(
self, message: Message, slack_view: SlackViewInterface
) -> bool:
"""
A job is always request we only receive webhooks for events associated with the slack bot
"""A job is always request we only receive webhooks for events associated with the slack bot
This method really just checks
1. Is the user is authenticated
2. Do we have the necessary information to start a job (either by inferring the selected repo, otherwise asking the user)
"""
# Infer repo from user message is not needed; user selected repo from the form or is updating existing convo
if isinstance(slack_view, SlackUpdateExistingConversationView):
return True

View File

@@ -24,17 +24,17 @@ class SlackViewInterface(SummaryExtractionTracker, ABC):
@abstractmethod
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"Instructions passed when conversation is first initialized"
"""Instructions passed when conversation is first initialized"""
pass
@abstractmethod
async def create_or_update_conversation(self, jinja_env: Environment):
"Create a new conversation"
"""Create a new conversation"""
pass
@abstractmethod
def get_callback_id(self) -> str:
"Unique callback id for subscribription made to EventStream for fetching agent summary"
"""Unique callback id for subscribription made to EventStream for fetching agent summary"""
pass
@abstractmethod
@@ -43,6 +43,4 @@ class SlackViewInterface(SummaryExtractionTracker, ABC):
class StartingConvoException(Exception):
"""
Raised when trying to send message to a conversation that's is still starting up
"""
"""Raised when trying to send message to a conversation that's is still starting up"""

View File

@@ -95,8 +95,7 @@ class SlackNewConversationView(SlackViewInterface):
return ''
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"Instructions passed when conversation is first initialized"
"""Instructions passed when conversation is first initialized"""
user_info: SlackUser = self.slack_to_openhands_user
messages = []
@@ -179,9 +178,7 @@ class SlackNewConversationView(SlackViewInterface):
await slack_conversation_store.create_slack_conversation(slack_conversation)
async def create_or_update_conversation(self, jinja: Environment) -> str:
"""
Only creates a new conversation
"""
"""Only creates a new conversation"""
self._verify_necessary_values_are_set()
provider_tokens = await self.saas_user_auth.get_provider_tokens()
@@ -246,9 +243,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
return user_message, ''
async def create_or_update_conversation(self, jinja: Environment) -> str:
"""
Send new user message to converation
"""
"""Send new user message to converation"""
user_info: SlackUser = self.slack_to_openhands_user
saas_user_auth: UserAuth = self.saas_user_auth
user_id = user_info.keycloak_user_id

View File

@@ -1,5 +1,4 @@
"""
Utilities for loading and managing pre-trained classifiers.
"""Utilities for loading and managing pre-trained classifiers.
Assumes that classifiers are stored adjacent to this file in the `solvability/data` directory, using a simple
`name + .json` pattern.
@@ -11,8 +10,7 @@ from integrations.solvability.models.classifier import SolvabilityClassifier
def load_classifier(name: str) -> SolvabilityClassifier:
"""
Load a classifier by name.
"""Load a classifier by name.
Args:
name (str): The name of the classifier to load.
@@ -31,8 +29,7 @@ def load_classifier(name: str) -> SolvabilityClassifier:
def available_classifiers() -> list[str]:
"""
List all available classifiers in the data directory.
"""List all available classifiers in the data directory.
Returns:
list[str]: A list of classifier names (without the .json extension).

View File

@@ -1,5 +1,4 @@
"""
Solvability Models Package
"""Solvability Models Package
This package contains the core machine learning models and components for predicting
the solvability of GitHub issues and similar technical problems.

View File

@@ -26,8 +26,7 @@ from openhands.core.config import LLMConfig
class SolvabilityClassifier(BaseModel):
"""
Machine learning pipeline for predicting the solvability of GitHub issues and similar problems.
"""Machine learning pipeline for predicting the solvability of GitHub issues and similar problems.
This classifier combines LLM-based feature extraction with traditional ML classification:
1. Uses a Featurizer to extract semantic boolean features from issue descriptions via LLM calls
@@ -87,9 +86,7 @@ class SolvabilityClassifier(BaseModel):
@model_validator(mode='after')
def validate_random_state(self) -> SolvabilityClassifier:
"""
Validate the random state configuration between this object and the classifier.
"""
"""Validate the random state configuration between this object and the classifier."""
# If both random states are set, they definitely need to agree.
if self.random_state is not None and self.classifier.random_state is not None:
if self.random_state != self.classifier.random_state:
@@ -104,9 +101,7 @@ class SolvabilityClassifier(BaseModel):
@property
def features_(self) -> pd.DataFrame:
"""
Get the features used by the classifier for the most recent inputs.
"""
"""Get the features used by the classifier for the most recent inputs."""
if 'features_' not in self._classifier_attrs:
raise ValueError(
'SolvabilityClassifier.transform() has not yet been called.'
@@ -115,9 +110,7 @@ class SolvabilityClassifier(BaseModel):
@property
def cost_(self) -> pd.DataFrame:
"""
Get the cost of the classifier for the most recent inputs.
"""
"""Get the cost of the classifier for the most recent inputs."""
if 'cost_' not in self._classifier_attrs:
raise ValueError(
'SolvabilityClassifier.transform() has not yet been called.'
@@ -126,9 +119,7 @@ class SolvabilityClassifier(BaseModel):
@property
def feature_importances_(self) -> np.ndarray:
"""
Get the feature importances for the most recent inputs.
"""
"""Get the feature importances for the most recent inputs."""
if 'feature_importances_' not in self._classifier_attrs:
raise ValueError(
'No SolvabilityClassifier methods that produce feature importances (.fit(), .predict_proba(), and '
@@ -138,9 +129,7 @@ class SolvabilityClassifier(BaseModel):
@property
def is_fitted(self) -> bool:
"""
Check if the classifier is fitted.
"""
"""Check if the classifier is fitted."""
try:
check_is_fitted(self.classifier)
return True
@@ -148,8 +137,7 @@ class SolvabilityClassifier(BaseModel):
return False
def transform(self, issues: pd.Series, llm_config: LLMConfig) -> pd.DataFrame:
"""
Transform the input issues using the featurizer to extract features.
"""Transform the input issues using the featurizer to extract features.
This method orchestrates the feature extraction pipeline:
1. Uses the featurizer to generate embeddings for all issues
@@ -183,8 +171,7 @@ class SolvabilityClassifier(BaseModel):
def fit(
self, issues: pd.Series, labels: pd.Series, llm_config: LLMConfig
) -> SolvabilityClassifier:
"""
Fit the classifier to the input issues and labels.
"""Fit the classifier to the input issues and labels.
Args:
issues: A pandas Series containing the issue descriptions.
@@ -208,8 +195,7 @@ class SolvabilityClassifier(BaseModel):
return self
def predict_proba(self, issues: pd.Series, llm_config: LLMConfig) -> np.ndarray:
"""
Predict the solvability probabilities for the input issues.
"""Predict the solvability probabilities for the input issues.
Returns class probabilities where the second column represents the probability
of the issue being solvable (positive class).
@@ -243,8 +229,7 @@ class SolvabilityClassifier(BaseModel):
return scores # type: ignore[no-any-return]
def predict(self, issues: pd.Series, llm_config: LLMConfig) -> np.ndarray:
"""
Predict the solvability of the input issues by returning binary labels.
"""Predict the solvability of the input issues by returning binary labels.
Uses a 0.5 probability threshold to convert probabilities to binary predictions.
@@ -266,8 +251,7 @@ class SolvabilityClassifier(BaseModel):
scores: np.ndarray,
labels: np.ndarray | None = None,
) -> np.ndarray:
"""
Calculate feature importance scores using the configured strategy.
"""Calculate feature importance scores using the configured strategy.
Different strategies provide different interpretations:
- SHAP: Shapley values indicating contribution to individual predictions
@@ -313,8 +297,7 @@ class SolvabilityClassifier(BaseModel):
)
def add_features(self, features: list[Feature]) -> SolvabilityClassifier:
"""
Add new features to the classifier's featurizer.
"""Add new features to the classifier's featurizer.
Note: Adding features after training requires retraining the classifier
since the feature space will have changed.
@@ -331,8 +314,7 @@ class SolvabilityClassifier(BaseModel):
return self
def forget_features(self, features: list[Feature]) -> SolvabilityClassifier:
"""
Remove features from the classifier's featurizer.
"""Remove features from the classifier's featurizer.
Note: Removing features after training requires retraining the classifier
since the feature space will have changed.
@@ -354,17 +336,13 @@ class SolvabilityClassifier(BaseModel):
@field_serializer('classifier')
@staticmethod
def _rfc_to_json(rfc: RandomForestClassifier) -> str:
"""
Convert a RandomForestClassifier to a JSON-compatible value (a string).
"""
"""Convert a RandomForestClassifier to a JSON-compatible value (a string)."""
return base64.b64encode(pickle.dumps(rfc)).decode('utf-8')
@field_validator('classifier', mode='before')
@staticmethod
def _json_to_rfc(value: str | RandomForestClassifier) -> RandomForestClassifier:
"""
Convert a JSON-compatible value (a string) back to a RandomForestClassifier.
"""
"""Convert a JSON-compatible value (a string) back to a RandomForestClassifier."""
if isinstance(value, RandomForestClassifier):
return value
@@ -383,8 +361,7 @@ class SolvabilityClassifier(BaseModel):
def solvability_report(
self, issue: str, llm_config: LLMConfig, **kwargs: Any
) -> SolvabilityReport:
"""
Generate a solvability report for the given issue.
"""Generate a solvability report for the given issue.
Args:
issue: The issue description for which to generate the report.
@@ -427,7 +404,5 @@ class SolvabilityClassifier(BaseModel):
def __call__(
self, issue: str, llm_config: LLMConfig, **kwargs: Any
) -> SolvabilityReport:
"""
Generate a solvability report for the given issue.
"""
"""Generate a solvability report for the given issue."""
return self.solvability_report(issue, llm_config=llm_config, **kwargs)

View File

@@ -10,8 +10,7 @@ from openhands.llm.llm import LLM
class Feature(BaseModel):
"""
Represents a single boolean feature that can be extracted from issue descriptions.
"""Represents a single boolean feature that can be extracted from issue descriptions.
Features are semantic properties of issues (e.g., "has_code_example", "requires_debugging")
that are evaluated by LLMs and used as input to the solvability classifier.
@@ -25,8 +24,7 @@ class Feature(BaseModel):
@property
def to_tool_description_field(self) -> dict[str, Any]:
"""
Convert this feature to a JSON schema field for LLM tool calling.
"""Convert this feature to a JSON schema field for LLM tool calling.
Returns:
dict: JSON schema field definition for this feature.
@@ -38,8 +36,7 @@ class Feature(BaseModel):
class EmbeddingDimension(BaseModel):
"""
Represents a single dimension (feature evaluation) within a feature embedding sample.
"""Represents a single dimension (feature evaluation) within a feature embedding sample.
Each dimension corresponds to one feature being evaluated as true/false for a given issue.
"""
@@ -60,8 +57,7 @@ Maps feature identifiers to their boolean evaluations.
class FeatureEmbedding(BaseModel):
"""
Represents the complete feature embedding for a single issue, including multiple samples
"""Represents the complete feature embedding for a single issue, including multiple samples
and associated metadata about the LLM calls used to generate it.
Multiple samples are collected to account for LLM variability and provide more robust
@@ -82,8 +78,7 @@ class FeatureEmbedding(BaseModel):
@property
def dimensions(self) -> list[str]:
"""
Get all unique feature identifiers present across all samples.
"""Get all unique feature identifiers present across all samples.
Returns:
list[str]: List of feature identifiers that appear in at least one sample.
@@ -94,8 +89,7 @@ class FeatureEmbedding(BaseModel):
return list(dims)
def coefficient(self, dimension: str) -> float | None:
"""
Calculate the average coefficient (0-1) for a specific feature dimension.
"""Calculate the average coefficient (0-1) for a specific feature dimension.
This computes the proportion of samples where the feature was evaluated as True,
providing a continuous feature value for the classifier.
@@ -117,8 +111,7 @@ class FeatureEmbedding(BaseModel):
return None
def to_row(self) -> dict[str, Any]:
"""
Convert the embedding to a flat dictionary suitable for DataFrame construction.
"""Convert the embedding to a flat dictionary suitable for DataFrame construction.
Returns:
dict[str, Any]: Dictionary with metadata fields and feature coefficients.
@@ -131,8 +124,7 @@ class FeatureEmbedding(BaseModel):
}
def sample_entropy(self) -> dict[str, float]:
"""
Calculate the Shannon entropy of feature evaluations across samples.
"""Calculate the Shannon entropy of feature evaluations across samples.
Higher entropy indicates more variability in LLM responses for a feature,
which may suggest ambiguity in the feature definition or issue description.
@@ -162,8 +154,7 @@ class FeatureEmbedding(BaseModel):
class Featurizer(BaseModel):
"""
Orchestrates LLM-based feature extraction from issue descriptions.
"""Orchestrates LLM-based feature extraction from issue descriptions.
The Featurizer uses structured LLM tool calling to evaluate boolean features
for issue descriptions. It handles prompt construction, tool schema generation,
@@ -180,8 +171,7 @@ class Featurizer(BaseModel):
"""List of features to extract from each issue description."""
def system_message(self) -> dict[str, Any]:
"""
Construct the system message for LLM conversations.
"""Construct the system message for LLM conversations.
Returns:
dict[str, Any]: System message dictionary for LLM API calls.
@@ -194,8 +184,7 @@ class Featurizer(BaseModel):
def user_message(
self, issue_description: str, set_cache: bool = True
) -> dict[str, Any]:
"""
Construct the user message containing the issue description.
"""Construct the user message containing the issue description.
Args:
issue_description: The description of the issue to analyze.
@@ -215,8 +204,7 @@ class Featurizer(BaseModel):
@property
def tool_choice(self) -> dict[str, Any]:
"""
Get the tool choice configuration for forcing LLM to use the featurizer tool.
"""Get the tool choice configuration for forcing LLM to use the featurizer tool.
Returns:
dict[str, Any]: Tool choice configuration for LLM API calls.
@@ -228,8 +216,7 @@ class Featurizer(BaseModel):
@property
def tool_description(self) -> dict[str, Any]:
"""
Generate the tool schema for the featurizer function.
"""Generate the tool schema for the featurizer function.
Creates a JSON schema that describes the featurizer tool with all configured
features as boolean parameters.
@@ -259,8 +246,7 @@ class Featurizer(BaseModel):
temperature: float = 1.0,
samples: int = 10,
) -> FeatureEmbedding:
"""
Generate a feature embedding for a single issue description.
"""Generate a feature embedding for a single issue description.
Makes multiple LLM calls to collect samples and reduce variance in feature evaluations.
Each call uses tool calling to extract structured boolean feature values.
@@ -322,8 +308,7 @@ class Featurizer(BaseModel):
temperature: float = 1.0,
samples: int = 10,
) -> list[FeatureEmbedding]:
"""
Generate embeddings for a batch of issue descriptions using concurrent processing.
"""Generate embeddings for a batch of issue descriptions using concurrent processing.
Processes multiple issues in parallel to improve throughput while maintaining
result ordering.
@@ -359,8 +344,7 @@ class Featurizer(BaseModel):
return results
def feature_identifiers(self) -> list[str]:
"""
Get the identifiers of all configured features.
"""Get the identifiers of all configured features.
Returns:
list[str]: List of feature identifiers in the order they were defined.

View File

@@ -2,8 +2,7 @@ from enum import Enum
class ImportanceStrategy(str, Enum):
"""
Strategy to use for calculating feature importances, which are used to estimate the predictive power of each feature
"""Strategy to use for calculating feature importances, which are used to estimate the predictive power of each feature
in training loops and explanations.
"""

View File

@@ -6,8 +6,7 @@ from pydantic import BaseModel, Field
class SolvabilityReport(BaseModel):
"""
Comprehensive report containing solvability predictions and analysis for a single issue.
"""Comprehensive report containing solvability predictions and analysis for a single issue.
This report includes the solvability score, extracted feature values, feature importance analysis,
cost metrics (tokens and latency), and metadata about the prediction process. It serves as the

View File

@@ -39,13 +39,13 @@ class ResolverViewInterface(SummaryExtractionTracker):
raw_payload: dict
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
"Instructions passed when conversation is first initialized"
"""Instructions passed when conversation is first initialized"""
raise NotImplementedError()
async def create_new_conversation(self, jinja_env: Environment, token: str):
"Create a new conversation"
"""Create a new conversation"""
raise NotImplementedError()
def get_callback_id(self) -> str:
"Unique callback id for subscribription made to EventStream for fetching agent summary"
"""Unique callback id for subscribription made to EventStream for fetching agent summary"""
raise NotImplementedError()

View File

@@ -215,9 +215,7 @@ def get_last_user_msg(event_store: EventStoreABC) -> list[MessageAction]:
def extract_summary_from_event_store(
event_store: EventStoreABC, conversation_id: str
) -> str:
"""
Get agent summary or alternative message depending on current AgentState
"""
"""Get agent summary or alternative message depending on current AgentState"""
conversation_link = CONVERSATION_URL.format(conversation_id)
summary_instruction = get_summary_instruction()
@@ -293,10 +291,7 @@ async def get_last_user_msg_from_conversation_manager(
async def extract_summary_from_conversation_manager(
conversation_manager: ConversationManager, conversation_id: str
) -> str:
"""
Get agent summary or alternative message depending on current AgentState
"""
"""Get agent summary or alternative message depending on current AgentState"""
event_store = await get_event_store_from_conversation_manager(
conversation_manager, conversation_id
)
@@ -305,8 +300,7 @@ async def extract_summary_from_conversation_manager(
def append_conversation_footer(message: str, conversation_id: str) -> str:
"""
Append a small footer with the conversation URL to a message.
"""Append a small footer with the conversation URL to a message.
Args:
message: The original message content
@@ -321,14 +315,12 @@ def append_conversation_footer(message: str, conversation_id: str) -> str:
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
"""
Store repositories in DB and create user-repository mappings
"""Store repositories in DB and create user-repository mappings
Args:
repos: List of Repository objects to store
user_id: User ID associated with these repositories
"""
# Convert Repository objects to StoredRepository objects
# Convert Repository objects to UserRepositoryMap objects
stored_repos = []
@@ -366,9 +358,9 @@ async def store_repositories_in_db(repos: list[Repository], user_id: str) -> Non
def infer_repo_from_message(user_msg: str) -> list[str]:
"""
Extract all repository names in the format 'owner/repo' from various Git provider URLs
"""Extract all repository names in the format 'owner/repo' from various Git provider URLs
and direct mentions in text. Supports GitHub, GitLab, and BitBucket.
Args:
user_msg: Input message that may contain repository references
Returns:
@@ -451,10 +443,10 @@ def filter_potential_repos_by_user_msg(
def markdown_to_jira_markup(markdown_text: str) -> str:
"""
Convert markdown text to Jira Wiki Markup format.
"""Convert markdown text to Jira Wiki Markup format.
This function handles common markdown elements and converts them to their
Jira Wiki Markup equivalents. It's designed to be exception-safe.
Args:
markdown_text: The markdown text to convert
Returns:

View File

@@ -22,8 +22,7 @@ depends_on = None
def upgrade():
"""
Create maintenance tasks for all users whose user_version is less than
"""Create maintenance tasks for all users whose user_version is less than
the current version.
This replaces the functionality of the removed admin maintenance endpoint.
@@ -89,8 +88,7 @@ def upgrade():
def downgrade():
"""
No downgrade operation needed as we're just creating tasks.
"""No downgrade operation needed as we're just creating tasks.
The tasks themselves will be processed and completed.
If needed, we could delete tasks with this processor type, but that's not necessary

View File

@@ -19,7 +19,10 @@ from server.auth.constants import ( # noqa: E402
from server.constants import PERMITTED_CORS_ORIGINS # noqa: E402
from server.logger import logger # noqa: E402
from server.metrics import metrics_app # noqa: E402
from server.middleware import SetAuthCookieMiddleware # noqa: E402
from server.middleware import ( # noqa: E402
LLMSettingsMiddleware,
SetAuthCookieMiddleware,
)
from server.rate_limit import setup_rate_limit_handler # noqa: E402
from server.routes.api_keys import api_router as api_keys_router # noqa: E402
from server.routes.auth import api_router, oauth_router # noqa: E402
@@ -105,6 +108,7 @@ base_app.add_middleware(
allow_headers=['*'],
)
base_app.add_middleware(CacheControlMiddleware)
base_app.middleware('http')(LLMSettingsMiddleware())
base_app.middleware('http')(SetAuthCookieMiddleware())
base_app.mount('/', SPAStaticFiles(directory=directory, html=True), name='dist')

View File

@@ -31,6 +31,7 @@ class GoogleSheetsClient:
self, spreadsheet_id: str, range_name: str
) -> Optional[List[str]]:
"""Get usernames from cache if available and not expired.
Args:
spreadsheet_id: The ID of the Google Sheet
range_name: The A1 notation of the range to fetch
@@ -56,6 +57,7 @@ class GoogleSheetsClient:
self, spreadsheet_id: str, range_name: str, usernames: List[str]
) -> None:
"""Update cache with new usernames and current timestamp.
Args:
spreadsheet_id: The ID of the Google Sheet
range_name: The A1 notation of the range to fetch
@@ -67,6 +69,7 @@ class GoogleSheetsClient:
def get_usernames(self, spreadsheet_id: str, range_name: str = 'A:A') -> List[str]:
"""Get list of usernames from specified Google Sheet.
Uses cached data if available and less than 15 seconds old.
Args:
spreadsheet_id: The ID of the Google Sheet
range_name: The A1 notation of the range to fetch

View File

@@ -483,8 +483,7 @@ class ClusteredConversationManager(StandaloneConversationManager):
await pipe.execute()
async def _disconnect_from_stopped(self):
"""
Handle connections to conversations that have stopped unexpectedly.
"""Handle connections to conversations that have stopped unexpectedly.
This method detects when a local connection is pointing to a conversation
that was running on another server that has crashed or been terminated

View File

@@ -70,8 +70,7 @@ PERMITTED_CORS_ORIGINS = [
def build_litellm_proxy_model_path(model_name: str) -> str:
"""
Build the LiteLLM proxy model path based on environment and model name.
"""Build the LiteLLM proxy model path based on environment and model name.
This utility constructs the full model path for LiteLLM proxy based on:
- Environment type (staging vs prod)
@@ -83,7 +82,6 @@ def build_litellm_proxy_model_path(model_name: str) -> str:
Returns:
The full LiteLLM proxy model path (e.g., 'litellm_proxy/prod/claude-3-7-sonnet-20250219')
"""
if 'prod' in model_name or 'litellm' in model_name or 'proxy' in model_name:
raise ValueError("Only include model name, don't include prefix")
@@ -96,8 +94,7 @@ def build_litellm_proxy_model_path(model_name: str) -> str:
def get_default_litellm_model():
"""
Construct proxy for litellm model based on user settings and environment type (staging vs prod)
"""Construct proxy for litellm model based on user settings and environment type (staging vs prod)
if not set explicitly
"""
if LITELLM_DEFAULT_MODEL:

View File

@@ -25,8 +25,7 @@ from openhands.server.shared import conversation_manager
class GithubCallbackProcessor(ConversationCallbackProcessor):
"""
Processor for sending conversation summaries to GitHub.
"""Processor for sending conversation summaries to GitHub.
This processor is used to send summaries of conversations to GitHub issues/PRs
when agent state changes occur.
@@ -36,8 +35,7 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
send_summary_instruction: bool = True
async def _send_message_to_github(self, message: str) -> None:
"""
Send a message to GitHub.
"""Send a message to GitHub.
Args:
message: The message content to send to GitHub
@@ -68,8 +66,7 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
callback: ConversationCallback,
observation: AgentStateChangedObservation,
) -> None:
"""
Process a conversation event by sending a summary to GitHub.
"""Process a conversation event by sending a summary to GitHub.
Args:
callback: The conversation callback

View File

@@ -28,8 +28,7 @@ gitlab_manager = GitlabManager(token_manager)
class GitlabCallbackProcessor(ConversationCallbackProcessor):
"""
Processor for sending conversation summaries to GitLab.
"""Processor for sending conversation summaries to GitLab.
This processor is used to send summaries of conversations to GitLab
when agent state changes occur.
@@ -39,8 +38,7 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
send_summary_instruction: bool = True
async def _send_message_to_gitlab(self, message: str) -> None:
"""
Send a message to GitLab.
"""Send a message to GitLab.
Args:
message: The message content to send to GitLab
@@ -67,8 +65,7 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
callback: ConversationCallback,
observation: AgentStateChangedObservation,
) -> None:
"""
Process a conversation event by sending a summary to GitLab.
"""Process a conversation event by sending a summary to GitLab.
Args:
callback: The conversation callback

View File

@@ -26,8 +26,7 @@ integration_store = jira_manager.integration_store
class JiraCallbackProcessor(ConversationCallbackProcessor):
"""
Processor for sending conversation summaries to Jira.
"""Processor for sending conversation summaries to Jira.
This processor is used to send summaries of conversations to Jira issues
when agent state changes occur.
@@ -37,8 +36,7 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
workspace_name: str
async def _send_comment_to_jira(self, message: str) -> None:
"""
Send a comment to Jira issue.
"""Send a comment to Jira issue.
Args:
message: The message content to send to Jira
@@ -79,8 +77,7 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
callback: ConversationCallback,
observation: AgentStateChangedObservation,
) -> None:
"""
Process a conversation event by sending a summary to Jira.
"""Process a conversation event by sending a summary to Jira.
Args:
callback: The conversation callback

View File

@@ -25,8 +25,7 @@ jira_dc_manager = JiraDcManager(token_manager)
class JiraDcCallbackProcessor(ConversationCallbackProcessor):
"""
Processor for sending conversation summaries to Jira DC.
"""Processor for sending conversation summaries to Jira DC.
This processor is used to send summaries of conversations to Jira DC issues
when agent state changes occur.
@@ -37,8 +36,7 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
base_api_url: str
async def _send_comment_to_jira_dc(self, message: str) -> None:
"""
Send a comment to Jira DC issue.
"""Send a comment to Jira DC issue.
Args:
message: The message content to send to Jira DC
@@ -80,8 +78,7 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
callback: ConversationCallback,
observation: AgentStateChangedObservation,
) -> None:
"""
Process a conversation event by sending a summary to Jira DC.
"""Process a conversation event by sending a summary to Jira DC.
Args:
callback: The conversation callback

View File

@@ -24,8 +24,7 @@ linear_manager = LinearManager(token_manager)
class LinearCallbackProcessor(ConversationCallbackProcessor):
"""
Processor for sending conversation summaries to Linear.
"""Processor for sending conversation summaries to Linear.
This processor is used to send summaries of conversations to Linear issues
when agent state changes occur.
@@ -36,8 +35,7 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
workspace_name: str
async def _send_comment_to_linear(self, message: str) -> None:
"""
Send a comment to Linear issue.
"""Send a comment to Linear issue.
Args:
message: The message content to send to Linear
@@ -79,8 +77,7 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
callback: ConversationCallback,
observation: AgentStateChangedObservation,
) -> None:
"""
Process a conversation event by sending a summary to Linear.
"""Process a conversation event by sending a summary to Linear.
Args:
callback: The conversation callback

View File

@@ -26,8 +26,7 @@ slack_manager = SlackManager(token_manager)
class SlackCallbackProcessor(ConversationCallbackProcessor):
"""
Processor for sending conversation summaries to Slack.
"""Processor for sending conversation summaries to Slack.
This processor is used to send summaries of conversations to Slack channels
when agent state changes occur.
@@ -41,8 +40,7 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
last_user_msg_id: int | None = None
async def _send_message_to_slack(self, message: str) -> None:
"""
Send a message to Slack using the conversation_manager's send_to_event_stream method.
"""Send a message to Slack using the conversation_manager's send_to_event_stream method.
Args:
message: The message content to send to Slack
@@ -83,8 +81,7 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
callback: ConversationCallback,
observation: AgentStateChangedObservation,
) -> None:
"""
Process a conversation event by sending a summary to Slack.
"""Process a conversation event by sending a summary to Slack.
Args:
conversation_id: The ID of the conversation to process

View File

@@ -33,8 +33,7 @@ class LegacyCacheEntry:
@dataclass
class LegacyConversationManager(ConversationManager):
"""
Conversation manager for use while migrating - since existing conversations are not nested!
"""Conversation manager for use while migrating - since existing conversations are not nested!
Separate class from SaasNestedConversationManager so it can be easliy removed in a few weeks.
(As of 2025-07-23)
"""
@@ -270,8 +269,7 @@ class LegacyConversationManager(ConversationManager):
del self._legacy_cache[key]
async def should_start_in_legacy_mode(self, conversation_id: str) -> bool:
"""
Check if a conversation should run in legacy mode by directly checking the runtime.
"""Check if a conversation should run in legacy mode by directly checking the runtime.
The /list method does not include stopped conversations even though the PVC for these
may not yet have been deleted, so we need to check /sessions/{session_id} directly.
"""
@@ -295,8 +293,7 @@ class LegacyConversationManager(ConversationManager):
return is_legacy
def is_legacy_runtime(self, runtime: dict | None) -> bool:
"""
Determine if a runtime is a legacy runtime based on its command.
"""Determine if a runtime is a legacy runtime based on its command.
Args:
runtime: The runtime dictionary or None if not found

View File

@@ -59,11 +59,9 @@ def setup_json_logger(
level: str = LOG_LEVEL,
_out: TextIO = sys.stdout,
) -> None:
"""
Configure logger instance to output json for Google Cloud.
"""Configure logger instance to output json for Google Cloud.
Existing filters should stay in place for sensitive content.
"""
# Remove existing handlers to avoid duplicate logs
for handler in logger.handlers[:]:
logger.removeHandler(handler)
@@ -84,8 +82,7 @@ def setup_json_logger(
def setup_all_loggers():
"""
Setup JSON logging for all libraries that may be logging.
"""Setup JSON logging for all libraries that may be logging.
Leave OpenHands alone since it's already configured.
"""
if LOG_JSON:

View File

@@ -13,8 +13,7 @@ from openhands.core.config import load_openhands_config
class UserVersionUpgradeProcessor(MaintenanceTaskProcessor):
"""
Processor for upgrading user settings to the current version.
"""Processor for upgrading user settings to the current version.
This processor takes a list of user IDs and upgrades any users
whose user_version is less than CURRENT_USER_SETTINGS_VERSION.
@@ -23,8 +22,7 @@ class UserVersionUpgradeProcessor(MaintenanceTaskProcessor):
user_ids: List[str]
async def __call__(self, task: MaintenanceTask) -> dict:
"""
Process user version upgrades for the specified user IDs.
"""Process user version upgrades for the specified user IDs.
Args:
task: The maintenance task being processed

View File

@@ -27,8 +27,7 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
def create_default_mcp_server_config(
host: str, config: 'OpenHandsConfig', user_id: str | None = None
) -> tuple[MCPSHTTPServerConfig | None, list[MCPStdioServerConfig]]:
"""
Create a default MCP server configuration.
"""Create a default MCP server configuration.
Args:
host: Host string
@@ -36,7 +35,6 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
Returns:
A tuple containing the default SSE server configuration and a list of MCP stdio server configurations
"""
api_key_store = ApiKeyStore.get_instance()
if user_id:
api_key = api_key_store.retrieve_mcp_api_key(user_id)

View File

@@ -33,8 +33,7 @@ def metrics_app() -> Callable:
metrics_callable = make_asgi_app()
async def wrapped_handler(scope, receive, send):
"""
Call _update_metrics before serving Prometheus metrics endpoint.
"""Call _update_metrics before serving Prometheus metrics endpoint.
Not wrapped in a `try`, failing would make metrics endpoint unavailable.
"""
await _update_metrics()

View File

@@ -1,7 +1,8 @@
from datetime import UTC, datetime
from typing import Callable
import jwt
from fastapi import Request, Response, status
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import JSONResponse
from pydantic import SecretStr
from server.auth.auth_error import (
@@ -12,21 +13,25 @@ from server.auth.auth_error import (
)
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
from server.auth.saas_user_auth import SaasUserAuth, token_manager
from server.constants import get_default_litellm_model
from server.routes.auth import (
get_cookie_domain,
get_cookie_samesite,
set_response_cookie,
)
from storage.database import session_maker
from storage.subscription_access import SubscriptionAccess
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth.user_auth import AuthType, get_user_auth
from openhands.server.utils import config
from openhands.storage.data_models.settings import Settings
class SetAuthCookieMiddleware:
"""
Update the auth cookie with the current authentication state if it was refreshed before sending response to user.
Deleting invalid cookies is handled by CookieError using FastAPIs standard error handling mechanism
"""Update the auth cookie with the current authentication state if it was refreshed before sending response to user.
Deleting invalid cookies is handled by CookieError using FastAPIs standard error handling mechanism.
"""
async def __call__(self, request: Request, call_next: Callable):
@@ -172,3 +177,247 @@ class SetAuthCookieMiddleware:
await token_manager.logout(user_auth.refresh_token.get_secret_value())
except Exception:
logger.debug('Error logging out')
class LLMSettingsMiddleware:
"""Middleware to validate LLM settings access for enterprise users.
Intercepts POST requests to /api/settings and validates that non-pro users
cannot modify LLM-related settings.
"""
async def __call__(self, request: Request, call_next: Callable):
try:
logger.warning(
f'LLM middleware called for {request.method} {request.url.path}'
)
# Check if this is a POST request to /api/settings
if request.method == 'POST' and request.url.path == '/api/settings':
logger.warning('LLM middleware intercepting POST /api/settings request')
await self._validate_llm_settings_request(request)
# Continue with the request
response: Response = await call_next(request)
return response
except HTTPException:
# Re-raise HTTPException (our 403 response)
raise
except Exception as e:
logger.warning(f'Error in LLM settings middleware: {e}')
# Let other errors pass through to be handled by the route
fallback_response: Response = await call_next(request)
return fallback_response
async def _validate_llm_settings_request(self, request: Request) -> None:
"""Validate LLM settings access for the current request."""
try:
logger.info(
f"LLM settings middleware intercepting POST /api/settings from {request.client.host if request.client else 'unknown'}"
)
# Get user authentication - this will trigger authentication if not already done
try:
user_auth = await get_user_auth(request)
except Exception as e:
logger.info(f'No valid user auth found ({e}), letting route handle request')
return # No user auth, let the route handle it
user_id = await user_auth.get_user_id()
if not user_id:
logger.info('No user ID found, letting route handle request')
return # No user ID, let the route handle it
logger.info(f'Processing settings request for user: {user_id}')
# Parse the request JSON to get new settings
try:
settings_data = await request.json()
logger.info(f'Parsed settings data keys: {list(settings_data.keys())}')
except Exception as e:
logger.warning(f'Invalid JSON in request body: {e}')
return # Invalid JSON, let the route handle it
# Convert to Settings object for validation
try:
new_settings = Settings(**settings_data)
logger.info('Successfully created Settings object from request data')
except Exception as e:
logger.warning(f'Invalid settings format: {e}')
return # Invalid settings format, let the route handle it
# Validate LLM settings access by comparing new settings against SaaS defaults
await validate_llm_settings_access(user_id, new_settings)
logger.info(f'LLM settings validation passed for user {user_id}')
except HTTPException as e:
logger.warning(
f'LLM settings validation failed: HTTP {e.status_code} - {e.detail}'
)
# Re-raise our 403 response
raise
except Exception as e:
logger.warning(f'Unexpected error validating LLM settings request: {e}')
# Let other errors pass through
def _get_saas_default_settings() -> Settings:
"""Get the default SaaS settings for comparison."""
return Settings(
language='en',
agent='CodeActAgent',
enable_proactive_conversation_starters=True,
enable_default_condenser=True,
condenser_max_size=120,
llm_model=get_default_litellm_model(), # litellm_proxy/prod/claude-sonnet-4-20250514
confirmation_mode=False,
security_analyzer='llm',
# Note: llm_api_key and llm_base_url are auto-provisioned for SaaS users,
# so we don't include them in defaults - any custom values are changes
)
def has_llm_settings_changes(user_settings: Settings, saas_defaults: Settings) -> bool:
"""Check if user settings contain changes to LLM-related settings from SaaS defaults."""
logger.info(
f"Checking LLM settings changes - User settings: {user_settings.model_dump(exclude={'secrets_store'})}"
)
logger.info(
f"Checking LLM settings changes - SaaS defaults: {saas_defaults.model_dump(exclude={'secrets_store'})}"
)
# Core LLM settings - any custom values are changes since SaaS auto-provisions these
if (
user_settings.llm_model is not None
and user_settings.llm_model != saas_defaults.llm_model
):
logger.warning(
f"LLM model change detected: user='{user_settings.llm_model}' vs default='{saas_defaults.llm_model}'"
)
return True
if user_settings.llm_api_key is not None:
# Any custom API key is a change (SaaS users get auto-provisioned keys)
logger.warning(
f'LLM API key change detected: user has custom key (length={len(user_settings.llm_api_key.get_secret_value()) if user_settings.llm_api_key else 0})'
)
return True
if user_settings.llm_base_url is not None and user_settings.llm_base_url != '':
# Any non-empty base URL is a change (SaaS users get auto-provisioned URL)
logger.warning(
f"LLM base URL change detected: user='{user_settings.llm_base_url}' (non-empty)"
)
return True
# LLM-related configuration settings
if user_settings.agent is not None and user_settings.agent != saas_defaults.agent:
logger.warning(
f"Agent change detected: user='{user_settings.agent}' vs default='{saas_defaults.agent}'"
)
return True
if (
user_settings.confirmation_mode is not None
and user_settings.confirmation_mode != saas_defaults.confirmation_mode
):
logger.warning(
f'Confirmation mode change detected: user={user_settings.confirmation_mode} vs default={saas_defaults.confirmation_mode}'
)
return True
if (
user_settings.security_analyzer is not None
and user_settings.security_analyzer != saas_defaults.security_analyzer
and user_settings.security_analyzer != ''
): # Handle empty string as None
logger.warning(
f"Security analyzer change detected: user='{user_settings.security_analyzer}' vs default='{saas_defaults.security_analyzer}'"
)
return True
if user_settings.max_budget_per_task is not None:
logger.warning(
f'Max budget per task change detected: user={user_settings.max_budget_per_task}'
)
return True
if user_settings.max_iterations is not None:
logger.warning(
f'Max iterations change detected: user={user_settings.max_iterations}'
)
return True
# Memory/context management settings
if user_settings.enable_default_condenser != saas_defaults.enable_default_condenser:
logger.warning(
f'Enable default condenser change detected: user={user_settings.enable_default_condenser} vs default={saas_defaults.enable_default_condenser}'
)
return True
if (
user_settings.condenser_max_size is not None
and user_settings.condenser_max_size != saas_defaults.condenser_max_size
):
logger.warning(
f'Condenser max size change detected: user={user_settings.condenser_max_size} vs default={saas_defaults.condenser_max_size}'
)
return True
logger.info('No LLM settings changes detected')
return False
def _has_active_subscription(user_id: str) -> bool:
"""Check if user has an active subscription (pro user)."""
with session_maker() as session:
now = datetime.now(UTC)
logger.info(f'Checking subscription for user {user_id} at time {now}')
subscription_access = (
session.query(SubscriptionAccess)
.filter(SubscriptionAccess.status == 'ACTIVE')
.filter(SubscriptionAccess.user_id == user_id)
.filter(SubscriptionAccess.start_at <= now)
.filter(SubscriptionAccess.end_at >= now)
.first()
)
if subscription_access:
logger.info(
f'Found active subscription for user {user_id}: starts={subscription_access.start_at}, ends={subscription_access.end_at}'
)
else:
logger.info(f'No active subscription found for user {user_id}')
return subscription_access is not None
async def validate_llm_settings_access(
user_id: str, user_settings: Settings, saas_defaults: Settings | None = None
) -> None:
"""Validate that user has permission to change LLM settings.
Raises HTTPException with 403 status if non-pro user tries to change LLM settings.
"""
if saas_defaults is None:
saas_defaults = _get_saas_default_settings()
logger.info(f'Validating LLM settings access for user: {user_id}')
# Check if user is trying to change LLM settings
if has_llm_settings_changes(user_settings, saas_defaults):
logger.warning(f'User {user_id} attempting to change LLM settings')
# Check if user has active subscription (is pro user)
has_subscription = _has_active_subscription(user_id)
logger.info(
f"User {user_id} subscription status: {'active' if has_subscription else 'none'}"
)
if not has_subscription:
logger.warning(
f'Blocking non-pro user {user_id} from changing LLM settings'
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='LLM settings can only be modified by pro users',
)
else:
logger.info(f'Allowing pro user {user_id} to change LLM settings')
else:
logger.info(f'User {user_id} making non-LLM settings changes only - allowing')

View File

@@ -1,5 +1,4 @@
"""
Usage:
"""Usage:
Call setup_rate_limit_handler on your FastAPI app to add the exception handler
@@ -23,9 +22,7 @@ from openhands.core.logger import openhands_logger as logger
def setup_rate_limit_handler(app: Starlette):
"""
Add exception handler that
"""
"""Add exception handler that"""
app.add_exception_handler(RateLimitException, _rate_limit_exceeded_handler)
@@ -56,8 +53,7 @@ class RateLimiter:
self.limit_items = limits.parse_many(windows)
async def hit(self, namespace: str, key: str):
"""
Raises RateLimitException when limit is hit.
"""Raises RateLimitException when limit is hit.
Logs and swallows exceptions and logs if lookup fails.
"""
for lim in self.limit_items:
@@ -80,9 +76,7 @@ class RateLimiter:
async def _get_stats_as_result(
self, lim: limits.RateLimitItem, namespace: str, key: str
) -> RateLimitResult:
"""
Lookup rate limit window stats and return a RateLimitResult with the data needed for response headers.
"""
"""Lookup rate limit window stats and return a RateLimitResult with the data needed for response headers."""
stats: limits.WindowStats = await self.strategy.get_window_stats(
lim, namespace, key
)
@@ -97,8 +91,7 @@ class RateLimiter:
def create_redis_rate_limiter(windows: str) -> RateLimiter:
"""
Create a RateLimiter with the Redis backend and "Fixed Window" strategy.
"""Create a RateLimiter with the Redis backend and "Fixed Window" strategy.
windows arg example: "10/second; 100/minute"
"""
backend = limits.aio.storage.RedisStorage(f'async+{get_redis_authed_url()}')
@@ -107,9 +100,7 @@ def create_redis_rate_limiter(windows: str) -> RateLimiter:
class RateLimitException(HTTPException):
"""
exception raised when a rate limit is hit.
"""
"""exception raised when a rate limit is hit."""
result: RateLimitResult
@@ -121,9 +112,7 @@ class RateLimitException(HTTPException):
def _rate_limit_exceeded_handler(request: Request, exc: Exception) -> Response:
"""
Build a simple JSON response that includes the details of the rate limit that was hit.
"""
"""Build a simple JSON response that includes the details of the rate limit that was hit."""
logger.info(exc.__class__.__name__)
if isinstance(exc, RateLimitException):
response = JSONResponse(

View File

@@ -17,8 +17,7 @@ ADD_DEBUGGING_ROUTES = os.environ.get('ADD_DEBUGGING_ROUTES') in ('1', 'true')
def add_debugging_routes(api: FastAPI):
"""
# HERE BE DRAGONS!
"""# HERE BE DRAGONS!
Chaos scripts for debugging and stress testing the system.
This module contains endpoints that deliberately stress test and potentially break
@@ -31,7 +30,6 @@ def add_debugging_routes(api: FastAPI):
- Testing async vs sync database access patterns
- Simulating event loop blocking
"""
if not ADD_DEBUGGING_ROUTES:
return
@@ -39,8 +37,7 @@ def add_debugging_routes(api: FastAPI):
@chaos_router.get('/pool-stats')
def pool_stats() -> dict[str, int]:
"""
Returns current database connection pool statistics.
"""Returns current database connection pool statistics.
This endpoint provides real-time metrics about the SQLAlchemy connection pool:
- checked_in: Number of connections currently available in the pool
@@ -55,8 +52,7 @@ def add_debugging_routes(api: FastAPI):
@chaos_router.get('/test-db')
def test_db(num_tests: int = 10, delay: int = 1) -> str:
"""
Stress tests the database connection pool using multiple threads.
"""Stress tests the database connection pool using multiple threads.
Creates multiple threads that each open a database connection, perform a query,
hold the connection for the specified delay, and then release it.
@@ -77,8 +73,7 @@ def add_debugging_routes(api: FastAPI):
@chaos_router.get('/a-test-db')
async def a_chaos_monkey(num_tests: int = 10, delay: int = 1) -> str:
"""
Stress tests the async database connection pool.
"""Stress tests the async database connection pool.
Similar to /test-db but uses async connections and coroutines instead of threads.
This endpoint helps compare the behavior of async vs sync connection pools
@@ -93,8 +88,7 @@ def add_debugging_routes(api: FastAPI):
@chaos_router.get('/lock-main-runloop')
async def lock_main_runloop(duration: int = 10) -> str:
"""
Deliberately blocks the main asyncio event loop.
"""Deliberately blocks the main asyncio event loop.
This endpoint uses a synchronous sleep operation in an async function,
which blocks the entire FastAPI server's event loop for the specified duration.
@@ -113,8 +107,7 @@ def add_debugging_routes(api: FastAPI):
def _db_check(delay: int):
"""
Executes a single request against the database with an artificial delay.
"""Executes a single request against the database with an artificial delay.
This helper function:
1. Opens a database connection from the pool
@@ -141,8 +134,7 @@ def _db_check(delay: int):
async def _a_db_check(delay: int):
"""
Executes a single async request against the database with an artificial delay.
"""Executes a single async request against the database with an artificial delay.
This is the async version of _db_check that:
1. Opens an async database connection from the pool

View File

@@ -73,8 +73,7 @@ class FeedbackRequest(BaseModel):
@router.post('/conversation', status_code=status.HTTP_201_CREATED)
async def submit_conversation_feedback(feedback: FeedbackRequest):
"""
Submit feedback for a conversation.
"""Submit feedback for a conversation.
This endpoint accepts a rating (1-5) and optional reason for the feedback.
The feedback is associated with a specific conversation and optionally a specific event.
@@ -108,8 +107,7 @@ async def submit_conversation_feedback(feedback: FeedbackRequest):
@router.get('/conversation/{conversation_id}/batch')
async def get_batch_feedback(conversation_id: str, user_id: str = Depends(get_user_id)):
"""
Get feedback for all events in a conversation.
"""Get feedback for all events in a conversation.
Returns feedback status for each event, including whether feedback exists
and if so, the rating and reason.

View File

@@ -16,8 +16,7 @@ GITHUB_PROXY_ENDPOINTS = bool(os.environ.get('GITHUB_PROXY_ENDPOINTS'))
def add_github_proxy_routes(app: FastAPI):
"""
Authentication endpoints for feature branches.
"""Authentication endpoints for feature branches.
# Requirements
* This should never be enabled in prod!

View File

@@ -23,14 +23,10 @@ AGENT_SESSION_START_HISTOGRAM = Histogram(
class SaaSMonitoringListener(MonitoringListener):
"""
Forward app signals to Prometheus.
"""
"""Forward app signals to Prometheus."""
def on_session_event(self, event: Event) -> None:
"""
Track metrics about events being added to a Session's EventStream.
"""
"""Track metrics about events being added to a Session's EventStream."""
if (
isinstance(event, AgentStateChangedObservation)
and event.agent_state == AgentState.ERROR
@@ -42,8 +38,7 @@ class SaaSMonitoringListener(MonitoringListener):
)
def on_agent_session_start(self, success: bool, duration: float) -> None:
"""
Track an agent session start.
"""Track an agent session start.
Success is true if startup completed without error.
Duration is start time in seconds observed by AgentSession.
"""
@@ -58,8 +53,7 @@ class SaaSMonitoringListener(MonitoringListener):
)
def on_create_conversation(self) -> None:
"""
Track the beginning of conversation creation.
"""Track the beginning of conversation creation.
Does not currently capture whether it succeed.
"""
CREATE_CONVERSATION_COUNT.inc()

View File

@@ -131,9 +131,7 @@ class SaasNestedConversationManager(ConversationManager):
async def get_running_agent_loops(
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
) -> set[str]:
"""
Get the running agent loops directly from the remote runtime.
"""
"""Get the running agent loops directly from the remote runtime."""
conversation_ids = await self._get_all_running_conversation_ids()
if filter_to_sids is not None:
@@ -482,10 +480,7 @@ class SaasNestedConversationManager(ConversationManager):
)
def _get_user_id_from_conversation(self, conversation_id: str) -> str:
"""
Get user_id from conversation_id.
"""
"""Get user_id from conversation_id."""
with session_maker() as session:
conversation_metadata = (
session.query(StoredConversationMetadata)

View File

@@ -34,8 +34,7 @@ file_store = get_file_store(config.file_store, config.file_store_path)
async def process_event(
user_id: str, conversation_id: str, subpath: str, content: dict
):
"""
Process a conversation event and invoke any registered callbacks.
"""Process a conversation event and invoke any registered callbacks.
Args:
user_id: The user ID associated with the conversation
@@ -72,8 +71,7 @@ async def process_event(
async def invoke_conversation_callbacks(
conversation_id: str, observation: AgentStateChangedObservation
):
"""
Load and invoke all active callbacks for a conversation.
"""Load and invoke all active callbacks for a conversation.
Args:
conversation_id: The conversation ID to process callbacks for
@@ -119,8 +117,7 @@ async def invoke_conversation_callbacks(
def update_conversation_metadata(conversation_id: str, content: dict):
"""
Update conversation metadata with new content.
"""Update conversation metadata with new content.
Args:
conversation_id: The conversation ID to update
@@ -159,8 +156,7 @@ def update_conversation_metadata(conversation_id: str, content: dict):
def register_callback_processor(
conversation_id: str, processor: ConversationCallbackProcessor
) -> int:
"""
Register a callback processor for a conversation.
"""Register a callback processor for a conversation.
Args:
conversation_id: The conversation ID to register the callback for
@@ -182,8 +178,7 @@ def register_callback_processor(
def update_active_working_seconds(
event_store: EventStore, conversation_id: str, user_id: str, file_store: FileStore
):
"""
Calculate and update the total active working seconds for a conversation.
"""Calculate and update the total active working seconds for a conversation.
This function reads all events for the conversation, looks for AgentStateChanged
observations, and calculates the total time spent in a running state.
@@ -263,8 +258,7 @@ def update_active_working_seconds(
def update_agent_state(user_id: str, conversation_id: str, content: bytes):
"""
Update agent state file for a conversation.
"""Update agent state file for a conversation.
Args:
user_id: The user ID associated with the conversation

View File

@@ -3,9 +3,7 @@ from storage.base import Base
class ApiKey(Base):
"""
Represents an API key for a user.
"""
"""Represents an API key for a user."""
__tablename__ = 'api_keys'
id = Column(Integer, primary_key=True, autoincrement=True)

View File

@@ -73,8 +73,7 @@ class AuthTokenStore:
]
| None = None,
) -> Dict[str, str | int] | None:
"""
Load authentication tokens from the database and refresh them if necessary.
"""Load authentication tokens from the database and refresh them if necessary.
This method retrieves the current authentication tokens for the user and checks if they have expired.
It uses the provided `check_expiration_and_refresh` function to determine if the tokens need

View File

@@ -1,6 +1,4 @@
"""
Unified SQLAlchemy declarative base for all models.
"""
"""Unified SQLAlchemy declarative base for all models."""
from sqlalchemy.orm import declarative_base

View File

@@ -5,8 +5,7 @@ from storage.base import Base
class BillingSession(Base): # type: ignore
"""
Represents a Stripe billing session for credit purchases.
"""Represents a Stripe billing session for credit purchases.
Tracks the status of payment transactions and associated user information.
"""

View File

@@ -15,8 +15,7 @@ from openhands.utils.import_utils import get_impl
class ConversationCallbackProcessor(BaseModel, ABC):
"""
Abstract base class for conversation callback processors.
"""Abstract base class for conversation callback processors.
Conversation processors are invoked when events occur in a conversation
to perform additional processing, notifications, or integrations.
@@ -35,8 +34,7 @@ class ConversationCallbackProcessor(BaseModel, ABC):
callback: ConversationCallback,
observation: AgentStateChangedObservation,
) -> None:
"""
Process a conversation event.
"""Process a conversation event.
Args:
conversation_id: The ID of the conversation to process
@@ -54,9 +52,7 @@ class CallbackStatus(Enum):
class ConversationCallback(Base): # type: ignore
"""
Model for storing conversation callbacks that process conversation events.
"""
"""Model for storing conversation callbacks that process conversation events."""
__tablename__ = 'conversation_callbacks'
@@ -85,8 +81,7 @@ class ConversationCallback(Base): # type: ignore
)
def get_processor(self) -> ConversationCallbackProcessor:
"""
Get the processor instance from the stored processor type and JSON data.
"""Get the processor instance from the stored processor type and JSON data.
Returns:
ConversationCallbackProcessor: The processor instance
@@ -99,8 +94,7 @@ class ConversationCallback(Base): # type: ignore
return processor
def set_processor(self, processor: ConversationCallbackProcessor) -> None:
"""
Set the processor instance, storing its type and JSON representation.
"""Set the processor instance, storing its type and JSON representation.
Args:
processor: The ConversationCallbackProcessor instance to store

View File

@@ -1,5 +1,4 @@
"""
Database model for experiment assignments.
"""Database model for experiment assignments.
This model tracks which experiments a conversation is assigned to and what variant
they received from PostHog feature flags.

View File

@@ -1,5 +1,4 @@
"""
Store for managing experiment assignments.
"""Store for managing experiment assignments.
This store handles creating and updating experiment assignments for conversations.
"""
@@ -20,8 +19,7 @@ class ExperimentAssignmentStore:
experiment_name: str,
variant: str,
) -> None:
"""
Update the variant for a specific experiment.
"""Update the variant for a specific experiment.
Args:
conversation_id: The conversation ID

View File

@@ -3,9 +3,7 @@ from storage.base import Base
class GithubAppInstallation(Base): # type: ignore
"""
Represents a Github App Installation with associated token.
"""
"""Represents a Github App Installation with associated token."""
__tablename__ = 'github_app_installations'
id = Column(Integer, primary_key=True, autoincrement=True)

View File

@@ -13,9 +13,7 @@ class WebhookStatus(IntEnum):
class GitlabWebhook(Base): # type: ignore
"""
Represents a Gitlab webhook configuration for a repository or group.
"""
"""Represents a Gitlab webhook configuration for a repository or group."""
__tablename__ = 'gitlab_webhook'
id = Column(Integer, primary_key=True, autoincrement=True)

View File

@@ -86,7 +86,6 @@ class GitlabWebhookStore:
Raises:
ValueError: If neither project_id nor group_id is provided, or if both are provided.
"""
resource_type, resource_id = GitlabWebhookStore.determine_resource_type(webhook)
async with self.a_session_maker() as session:
async with session.begin():
@@ -110,7 +109,6 @@ class GitlabWebhookStore:
Raises:
ValueError: If neither project_id nor group_id is provided, or if both are provided.
"""
resource_type, resource_id = GitlabWebhookStore.determine_resource_type(webhook)
logger.info(
@@ -184,7 +182,6 @@ class GitlabWebhookStore:
Returns:
List of GitlabWebhook objects that need processing
"""
async with self.a_session_maker() as session:
query = (
select(GitlabWebhook)
@@ -198,9 +195,7 @@ class GitlabWebhookStore:
return list(webhooks)
async def get_webhook_secret(self, webhook_uuid: str, user_id: str) -> str | None:
"""
Get's webhook secret given the webhook uuid and admin keycloak user id
"""
"""Get's webhook secret given the webhook uuid and admin keycloak user id"""
async with self.a_session_maker() as session:
query = (
select(GitlabWebhook)

View File

@@ -23,7 +23,6 @@ class JiraDcIntegrationStore:
status: str = 'active',
) -> JiraDcWorkspace:
"""Create a new Jira DC workspace with encrypted sensitive data."""
with session_maker() as session:
workspace = JiraDcWorkspace(
name=name.lower(),
@@ -83,7 +82,6 @@ class JiraDcIntegrationStore:
status: str = 'active',
) -> JiraDcUser:
"""Create a new Jira DC workspace link."""
jira_dc_user = JiraDcUser(
keycloak_user_id=keycloak_user_id,
jira_dc_user_id=jira_dc_user_id,
@@ -125,7 +123,6 @@ class JiraDcIntegrationStore:
self, keycloak_user_id: str
) -> Optional[JiraDcUser]:
"""Retrieve user by Keycloak user ID."""
with session_maker() as session:
return (
session.query(JiraDcUser)
@@ -184,7 +181,6 @@ class JiraDcIntegrationStore:
self, keycloak_user_id: str, status: str
) -> JiraDcUser:
"""Update the status of a Jira DC user mapping."""
with session_maker() as session:
user = (
session.query(JiraDcUser)

View File

@@ -24,7 +24,6 @@ class JiraIntegrationStore:
status: str = 'active',
) -> JiraWorkspace:
"""Create a new Jira workspace with encrypted sensitive data."""
workspace = JiraWorkspace(
name=name.lower(),
jira_cloud_id=jira_cloud_id,
@@ -91,7 +90,6 @@ class JiraIntegrationStore:
status: str = 'active',
) -> JiraUser:
"""Create a new Jira workspace link."""
jira_user = JiraUser(
keycloak_user_id=keycloak_user_id,
jira_user_id=jira_user_id,

View File

@@ -24,7 +24,6 @@ class LinearIntegrationStore:
status: str = 'active',
) -> LinearWorkspace:
"""Create a new Linear workspace with encrypted sensitive data."""
workspace = LinearWorkspace(
name=name.lower(),
linear_org_id=linear_org_id,

View File

@@ -15,8 +15,7 @@ from openhands.utils.import_utils import get_impl
class MaintenanceTaskProcessor(BaseModel, ABC):
"""
Abstract base class for maintenance task processors.
"""Abstract base class for maintenance task processors.
Maintenance processors are invoked to perform background maintenance
tasks such as upgrading user settings, cleaning up data, etc.
@@ -31,8 +30,7 @@ class MaintenanceTaskProcessor(BaseModel, ABC):
@abstractmethod
async def __call__(self, task: MaintenanceTask) -> dict:
"""
Process a maintenance task.
"""Process a maintenance task.
Args:
task: The maintenance task to process
@@ -53,9 +51,7 @@ class MaintenanceTaskStatus(Enum):
class MaintenanceTask(Base): # type: ignore
"""
Model for storing maintenance tasks that perform background operations.
"""
"""Model for storing maintenance tasks that perform background operations."""
__tablename__ = 'maintenance_tasks'
@@ -83,8 +79,7 @@ class MaintenanceTask(Base): # type: ignore
)
def get_processor(self) -> MaintenanceTaskProcessor:
"""
Get the processor instance from the stored processor type and JSON data.
"""Get the processor instance from the stored processor type and JSON data.
Returns:
MaintenanceTaskProcessor: The processor instance
@@ -97,8 +92,7 @@ class MaintenanceTask(Base): # type: ignore
return processor
def set_processor(self, processor: MaintenanceTaskProcessor) -> None:
"""
Set the processor instance, storing its type and JSON representation.
"""Set the processor instance, storing its type and JSON representation.
Args:
processor: The MaintenanceTaskProcessor instance to store

View File

@@ -13,9 +13,7 @@ from storage.base import Base
class OpenhandsPR(Base): # type: ignore
"""
Represents a pull request created by OpenHands.
"""
"""Represents a pull request created by OpenHands."""
__tablename__ = 'openhands_prs'
id = Column(Integer, Identity(), primary_key=True)

View File

@@ -15,9 +15,7 @@ class OpenhandsPRStore:
session_maker: sessionmaker
def insert_pr(self, pr: OpenhandsPR) -> None:
"""
Insert a new PR or delete and recreate if repo_id and pr_number already exist.
"""
"""Insert a new PR or delete and recreate if repo_id and pr_number already exist."""
with self.session_maker() as session:
# Check if PR already exists
existing_pr = (
@@ -39,8 +37,7 @@ class OpenhandsPRStore:
session.commit()
def increment_process_attempts(self, repo_id: str, pr_number: int) -> bool:
"""
Increment the process attempts counter for a PR.
"""Increment the process attempts counter for a PR.
Args:
repo_id: Repository identifier
@@ -75,8 +72,7 @@ class OpenhandsPRStore:
num_openhands_review_comments: int,
num_openhands_general_comments: int,
) -> bool:
"""
Update OpenHands statistics for a PR with row-level locking and timestamp validation.
"""Update OpenHands statistics for a PR with row-level locking and timestamp validation.
Args:
repo_id: Repository identifier
@@ -126,8 +122,7 @@ class OpenhandsPRStore:
def get_unprocessed_prs(
self, limit: int = 50, max_retries: int = 3
) -> list[OpenhandsPR]:
"""
Get unprocessed PR entries from the OpenhandsPR table.
"""Get unprocessed PR entries from the OpenhandsPR table.
Args:
limit: Maximum number of PRs to retrieve (default: 50)

View File

@@ -34,8 +34,7 @@ class ProactiveConversationStore:
pr_number: int,
get_all_workflows: Callable,
) -> WorkflowRunGroup | None:
"""
1. Get the workflow based on repo_id, pr_number, commit
"""1. Get the workflow based on repo_id, pr_number, commit
2. If the field doesn't exist
- Fetch the workflow statuses and store them
- Create a new record
@@ -45,7 +44,6 @@ class ProactiveConversationStore:
This method uses an explicit transaction with row-level locking to ensure
thread safety when multiple processes access the same database rows.
"""
should_send = False
provider_repo_id = self.get_repo_id(provider, repo_id)
@@ -131,14 +129,12 @@ class ProactiveConversationStore:
return final_workflow_group
async def clean_old_convos(self, older_than_minutes=30):
"""
Clean up proactive conversation records that are older than the specified time.
"""Clean up proactive conversation records that are older than the specified time.
Args:
older_than_minutes: Number of minutes. Records older than this will be deleted.
Defaults to 30 minutes.
"""
# Calculate the cutoff time (current time - older_than_minutes)
cutoff_time = datetime.now(UTC) - timedelta(minutes=older_than_minutes)

View File

@@ -15,8 +15,7 @@ class RepositoryStore:
config: OpenHandsConfig
def store_projects(self, repositories: list[StoredRepository]) -> None:
"""
Store repositories in database
"""Store repositories in database
1. Make sure to store repositories if its ID doesn't exist
2. If repository ID already exists, make sure to only update the repo is_public and repo_name fields

View File

@@ -14,8 +14,7 @@ class SaasConversationValidator(ConversationValidator):
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
async def _validate_api_key(self, api_key: str) -> str | None:
"""
Validate an API key and return the user_id and github_user_id if valid.
"""Validate an API key and return the user_id and github_user_id if valid.
Args:
api_key: The API key to validate
@@ -49,8 +48,7 @@ class SaasConversationValidator(ConversationValidator):
async def _validate_conversation_access(
self, conversation_id: str, user_id: str
) -> bool:
"""
Validate that the user has access to the conversation.
"""Validate that the user has access to the conversation.
Args:
conversation_id: The ID of the conversation
@@ -81,8 +79,7 @@ class SaasConversationValidator(ConversationValidator):
cookies_str: str,
authorization_header: str | None = None,
) -> str | None:
"""
Validate the conversation access using either an API key from the Authorization header
"""Validate the conversation access using either an API key from the Authorization header
or a keycloak_auth cookie.
Args:

View File

@@ -14,8 +14,7 @@ class SlackConversationStore:
async def get_slack_conversation(
self, channel_id: str, parent_id: str
) -> SlackConversation | None:
"""
Get a slack conversation by channel_id and message_ts.
"""Get a slack conversation by channel_id and message_ts.
Both parameters are required to match for a conversation to be returned.
"""
with session_maker() as session:

View File

@@ -10,9 +10,7 @@ class SlackTeamStore:
session_maker: sessionmaker
def get_team_bot_token(self, team_id: str) -> str | None:
"""
Get a team's bot access token by team_id
"""
"""Get a team's bot access token by team_id"""
with session_maker() as session:
team = session.query(SlackTeam).filter(SlackTeam.team_id == team_id).first()
return team.bot_access_token if team else None
@@ -22,9 +20,7 @@ class SlackTeamStore:
team_id: str,
bot_access_token: str,
) -> SlackTeam:
"""
Create a new SlackTeam
"""
"""Create a new SlackTeam"""
slack_team = SlackTeam(team_id=team_id, bot_access_token=bot_access_token)
with session_maker() as session:
session.query(SlackTeam).filter(SlackTeam.team_id == team_id).delete()

View File

@@ -3,9 +3,7 @@ from storage.base import Base
class StoredRepository(Base): # type: ignore
"""
Represents a repositories fetched from git providers.
"""
"""Represents a repositories fetched from git providers."""
__tablename__ = 'repos'
id = Column(Integer, primary_key=True, autoincrement=True)

View File

@@ -5,9 +5,7 @@ from storage.base import Base
class StoredSettings(Base): # type: ignore
"""
Legacy user settings storage. This should be considered deprecated - use UserSettings isntead
"""
"""Legacy user settings storage. This should be considered deprecated - use UserSettings isntead"""
__tablename__ = 'settings'
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))

View File

@@ -3,8 +3,7 @@ from storage.base import Base
class StripeCustomer(Base): # type: ignore
"""
Represents a stripe customer. We can't simply use the stripe API for this because:
"""Represents a stripe customer. We can't simply use the stripe API for this because:
"Dont use search in read-after-write flows where strict consistency is necessary.
Under normal operating conditions, data is searchable in less than a minute.
Occasionally, propagation of new or updated data can be up to an hour behind during outages"

View File

@@ -5,8 +5,7 @@ from storage.base import Base
class SubscriptionAccess(Base): # type: ignore
"""
Represents a user's subscription access record.
"""Represents a user's subscription access record.
Tracks subscription status, duration, payment information, and cancellation status.
"""

View File

@@ -3,9 +3,7 @@ from storage.base import Base
class UserRepositoryMap(Base):
"""
Represents a map between user id and repo ids
"""
"""Represents a map between user id and repo ids"""
__tablename__ = 'user-repos'
id = Column(Integer, primary_key=True, autoincrement=True)

View File

@@ -16,8 +16,7 @@ class UserRepositoryMapStore:
config: OpenHandsConfig
def store_user_repo_mappings(self, mappings: list[UserRepositoryMap]) -> None:
"""
Store user-repository mappings in database
"""Store user-repository mappings in database
1. Make sure to store mappings if they don't exist
2. If a mapping already exists (same user_id and repo_id), update the admin field

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env python3
"""
Common Room Sync
"""Common Room Sync
This script queries the database to count conversations created by each user,
then creates or updates a signal in Common Room for each user with their

View File

@@ -14,8 +14,7 @@ data_collector = GitHubDataCollector()
def get_unprocessed_prs() -> list[OpenhandsPR]:
"""
Get unprocessed PR entries from the OpenhandsPR table.
"""Get unprocessed PR entries from the OpenhandsPR table.
Args:
limit: Maximum number of PRs to retrieve (default: 50)
@@ -29,19 +28,14 @@ def get_unprocessed_prs() -> list[OpenhandsPR]:
async def process_pr(pr: OpenhandsPR):
"""
Process a single PR to enrich its data.
"""
"""Process a single PR to enrich its data."""
logger.info(f'Processing PR #{pr.pr_number} from repo {pr.repo_name}')
await data_collector.save_full_pr(pr)
store.increment_process_attempts(pr.repo_id, pr.pr_number)
async def main():
"""
Main function to retrieve and process unprocessed PRs.
"""
"""Main function to retrieve and process unprocessed PRs."""
logger.info('Starting PR data enrichment process')
# Get unprocessed PRs

View File

@@ -49,9 +49,7 @@ class VerifyWebhookStatus:
webhook_store: GitlabWebhookStore,
webhook: GitlabWebhook,
):
"""
Check if the GitLab resource still exists
"""
"""Check if the GitLab resource still exists"""
from integrations.gitlab.gitlab_service import SaaSGitLabService
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
@@ -83,9 +81,7 @@ class VerifyWebhookStatus:
webhook_store: GitlabWebhookStore,
webhook: GitlabWebhook,
):
"""
Check is user still has permission to resource
"""
"""Check is user still has permission to resource"""
from integrations.gitlab.gitlab_service import SaaSGitLabService
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
@@ -120,9 +116,7 @@ class VerifyWebhookStatus:
webhook_store: GitlabWebhookStore,
webhook: GitlabWebhook,
):
"""
Check whether webhook already exists on resource
"""
"""Check whether webhook already exists on resource"""
from integrations.gitlab.gitlab_service import SaaSGitLabService
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
@@ -192,9 +186,7 @@ class VerifyWebhookStatus:
webhook_store: GitlabWebhookStore,
webhook: GitlabWebhook,
):
"""
Install webhook on resource
"""
"""Install webhook on resource"""
from integrations.gitlab.gitlab_service import SaaSGitLabService
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
@@ -241,8 +233,7 @@ class VerifyWebhookStatus:
)
async def install_webhooks(self):
"""
Periodically check the conditions for installing a webhook on resource as valid
"""Periodically check the conditions for installing a webhook on resource as valid
Rows with valid conditions with contain (webhook_exists=False, status=WebhookStatus.VERIFIED)
Conditions we check for
@@ -255,7 +246,6 @@ class VerifyWebhookStatus:
- resource was never setup with webhook
"""
from integrations.gitlab.gitlab_service import SaaSGitLabService
# Get an instance of the webhook store

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env python3
"""
Test script for Common Room conversation count sync.
"""Test script for Common Room conversation count sync.
This script tests the functionality of the Common Room sync script
without making any API calls to Common Room or database connections.

View File

@@ -1,6 +1,4 @@
"""
Shared fixtures for Jira integration tests.
"""
"""Shared fixtures for Jira integration tests."""
from unittest.mock import AsyncMock, MagicMock, patch

View File

@@ -1,6 +1,4 @@
"""
Unit tests for JiraManager.
"""
"""Unit tests for JiraManager."""
import hashlib
import hmac

View File

@@ -1,6 +1,4 @@
"""
Tests for Jira view classes and factory.
"""
"""Tests for Jira view classes and factory."""
from unittest.mock import AsyncMock, MagicMock, patch

View File

@@ -1,6 +1,4 @@
"""
Shared fixtures for Jira DC integration tests.
"""
"""Shared fixtures for Jira DC integration tests."""
from unittest.mock import AsyncMock, MagicMock, patch

View File

@@ -1,6 +1,4 @@
"""
Unit tests for JiraDcManager.
"""
"""Unit tests for JiraDcManager."""
import hashlib
import hmac

View File

@@ -1,6 +1,4 @@
"""
Tests for Jira DC view classes and factory.
"""
"""Tests for Jira DC view classes and factory."""
from unittest.mock import AsyncMock, MagicMock, patch

View File

@@ -1,6 +1,4 @@
"""
Shared fixtures for Linear integration tests.
"""
"""Shared fixtures for Linear integration tests."""
from unittest.mock import AsyncMock, MagicMock, patch

View File

@@ -1,6 +1,4 @@
"""
Unit tests for LinearManager.
"""
"""Unit tests for LinearManager."""
import hashlib
import hmac

View File

@@ -1,6 +1,4 @@
"""
Tests for Linear view classes and factory.
"""
"""Tests for Linear view classes and factory."""
from unittest.mock import AsyncMock, MagicMock, patch

View File

@@ -1,6 +1,4 @@
"""
Mock implementation of the stripe_service module for testing.
"""
"""Mock implementation of the stripe_service module for testing."""
from unittest.mock import AsyncMock, MagicMock

View File

@@ -1,6 +1,4 @@
"""
Tests for conversation_callback_utils.py
"""
"""Tests for conversation_callback_utils.py"""
from unittest.mock import Mock, patch

View File

@@ -1,6 +1,4 @@
"""
Shared fixtures for all tests.
"""
"""Shared fixtures for all tests."""
from typing import Any
from unittest.mock import MagicMock
@@ -44,8 +42,7 @@ def feature_embedding() -> FeatureEmbedding:
@pytest.fixture
def featurizer(mock_llm, features) -> Featurizer:
"""
Create a featurizer for testing.
"""Create a featurizer for testing.
Mocks out any calls to LLM.completion
"""

View File

@@ -1,6 +1,4 @@
"""
Unit tests for data loading functionality in solvability/data.
"""
"""Unit tests for data loading functionality in solvability/data."""
import json
import tempfile

View File

@@ -38,7 +38,6 @@ def mock_response():
def test_set_response_cookie(mock_response, mock_request):
"""Test setting the auth cookie on a response."""
with patch('server.routes.auth.config') as mock_config:
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'

View File

@@ -1,5 +1,4 @@
"""
This test file verifies that the billing routes correctly use the stripe_service
"""This test file verifies that the billing routes correctly use the stripe_service
functions with the new database-first approach.
"""
@@ -71,7 +70,6 @@ async def test_create_customer_setup_session_uses_customer_id():
@pytest.mark.asyncio
async def test_create_checkout_session_uses_customer_id():
"""Test that create_checkout_session uses a customer ID string"""
# Create a mock request
mock_request = MagicMock()
mock_request.state = {'user_id': 'test-user-id'}
@@ -157,7 +155,6 @@ async def test_create_checkout_session_uses_customer_id():
@pytest.mark.asyncio
async def test_has_payment_method_uses_customer_id():
"""Test that has_payment_method uses a customer ID string"""
# Create a mock request
mock_request = MagicMock()
mock_request.state = {'user_id': 'test-user-id'}

Some files were not shown because too many files have changed in this diff Show More