mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
9 Commits
fix-basic-
...
APP-19
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6776b8b34e | ||
|
|
14f50a4816 | ||
|
|
fa4278b3b0 | ||
|
|
4b2dfc1f0d | ||
|
|
24d5a33b0e | ||
|
|
ba1edf0f0a | ||
|
|
741ea28cd3 | ||
|
|
191e3b9cef | ||
|
|
eac39e85eb |
@@ -17,8 +17,7 @@ class SaaSExperimentManager(ExperimentManager):
|
||||
def run_conversation_variant_test(
|
||||
user_id, conversation_id, conversation_settings
|
||||
) -> ConversationInitData:
|
||||
"""
|
||||
Run conversation variant test and potentially modify the conversation settings
|
||||
"""Run conversation variant test and potentially modify the conversation settings
|
||||
based on the PostHog feature flags.
|
||||
|
||||
Args:
|
||||
@@ -53,8 +52,7 @@ class SaaSExperimentManager(ExperimentManager):
|
||||
def run_config_variant_test(
|
||||
user_id: str | None, conversation_id: str, config: OpenHandsConfig
|
||||
) -> OpenHandsConfig:
|
||||
"""
|
||||
Run agent config variant test and potentially modify the OpenHands config
|
||||
"""Run agent config variant test and potentially modify the OpenHands config
|
||||
based on the current experiment type and PostHog feature flags.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
LiteLLM model experiment handler.
|
||||
"""LiteLLM model experiment handler.
|
||||
|
||||
This module contains the handler for the LiteLLM model experiment.
|
||||
"""
|
||||
@@ -18,8 +17,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
def handle_litellm_default_model_experiment(
|
||||
user_id, conversation_id, conversation_settings
|
||||
):
|
||||
"""
|
||||
Handle the LiteLLM model experiment.
|
||||
"""Handle the LiteLLM model experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
System prompt experiment handler.
|
||||
"""System prompt experiment handler.
|
||||
|
||||
This module contains the handler for the system prompt experiment that uses
|
||||
the PostHog variant as the system prompt filename.
|
||||
@@ -17,8 +16,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def _get_system_prompt_variant(user_id, conversation_id):
|
||||
"""
|
||||
Get the system prompt variant for the experiment.
|
||||
"""Get the system prompt variant for the experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
@@ -119,8 +117,7 @@ def _get_system_prompt_variant(user_id, conversation_id):
|
||||
def handle_system_prompt_experiment(
|
||||
user_id, conversation_id, config: OpenHandsConfig
|
||||
) -> OpenHandsConfig:
|
||||
"""
|
||||
Handle the system prompt experiment for OpenHands config.
|
||||
"""Handle the system prompt experiment for OpenHands config.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
LiteLLM model experiment handler.
|
||||
"""LiteLLM model experiment handler.
|
||||
|
||||
This module contains the handler for the LiteLLM model experiment.
|
||||
"""
|
||||
@@ -110,8 +109,7 @@ def handle_claude4_vs_gpt5_experiment(
|
||||
conversation_id: str,
|
||||
conversation_settings: ConversationInitData,
|
||||
) -> ConversationInitData:
|
||||
"""
|
||||
Handle the LiteLLM model experiment.
|
||||
"""Handle the LiteLLM model experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
@@ -121,7 +119,6 @@ def handle_claude4_vs_gpt5_experiment(
|
||||
Returns:
|
||||
Modified conversation settings
|
||||
"""
|
||||
|
||||
enabled_variant = _get_model_variant(user_id, conversation_id)
|
||||
|
||||
if not enabled_variant:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Condenser max step experiment handler.
|
||||
"""Condenser max step experiment handler.
|
||||
|
||||
This module contains the handler for the condenser max step experiment that tests
|
||||
different max_size values for the condenser configuration.
|
||||
@@ -15,8 +14,7 @@ from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
|
||||
|
||||
def _get_condenser_max_step_variant(user_id, conversation_id):
|
||||
"""
|
||||
Get the condenser max step variant for the experiment.
|
||||
"""Get the condenser max step variant for the experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
@@ -119,8 +117,7 @@ def handle_condenser_max_step_experiment(
|
||||
conversation_id: str,
|
||||
conversation_settings: ConversationInitData,
|
||||
) -> ConversationInitData:
|
||||
"""
|
||||
Handle the condenser max step experiment for conversation settings.
|
||||
"""Handle the condenser max step experiment for conversation settings.
|
||||
|
||||
We should not modify persistent user settings. Instead, apply the experiment
|
||||
variant to the conversation's in-memory settings object for this session only.
|
||||
@@ -131,7 +128,6 @@ def handle_condenser_max_step_experiment(
|
||||
|
||||
Returns the (potentially) modified conversation_settings.
|
||||
"""
|
||||
|
||||
enabled_variant = _get_condenser_max_step_variant(user_id, conversation_id)
|
||||
|
||||
if enabled_variant is None:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Experiment versions package.
|
||||
"""Experiment versions package.
|
||||
|
||||
This package contains handlers for different experiment versions.
|
||||
"""
|
||||
|
||||
@@ -43,8 +43,7 @@ class TriggerType(str, Enum):
|
||||
|
||||
|
||||
class GitHubDataCollector:
|
||||
"""
|
||||
Saves data on Cloud Resolver Interactions
|
||||
"""Saves data on Cloud Resolver Interactions
|
||||
|
||||
1. We always save
|
||||
- Resolver trigger (comment or label)
|
||||
@@ -89,8 +88,7 @@ class GitHubDataCollector:
|
||||
self.conversation_id = None
|
||||
|
||||
async def _get_repo_node_id(self, repo_id: str, gh_client) -> str:
|
||||
"""
|
||||
Get the new GitHub GraphQL node ID for a repository using the GitHub client.
|
||||
"""Get the new GitHub GraphQL node ID for a repository using the GitHub client.
|
||||
|
||||
Args:
|
||||
repo_id: Numeric repository ID as string (e.g., "123456789")
|
||||
@@ -136,10 +134,7 @@ class GitHubDataCollector:
|
||||
def _get_issue_comments(
|
||||
self, installation_id: str, repo_name: str, issue_number: int, conversation_id
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieve all comments from an issue until a comment with conversation_id is found
|
||||
"""
|
||||
|
||||
"""Retrieve all comments from an issue until a comment with conversation_id is found"""
|
||||
try:
|
||||
installation_token = self._get_installation_access_token(installation_id)
|
||||
|
||||
@@ -175,18 +170,16 @@ class GitHubDataCollector:
|
||||
github_view: GithubIssue,
|
||||
trigger_type: TriggerType,
|
||||
) -> None:
|
||||
"""
|
||||
Save issue data when it's labeled with openhands
|
||||
"""Save issue data when it's labeled with openhands
|
||||
|
||||
1. Save under {conversation_dir}/{conversation_id}/github_data/issue_{issue_number}.json
|
||||
2. Save issue snapshot (title, body, comments)
|
||||
3. Save trigger type (label)
|
||||
4. Save PR opened (if exists, this information comes later when agent has finished its task)
|
||||
- Save commit shas
|
||||
- Save author info
|
||||
5. Was PR merged or closed
|
||||
1. Save under {conversation_dir}/{conversation_id}/github_data/issue_{issue_number}.json
|
||||
2. Save issue snapshot (title, body, comments)
|
||||
3. Save trigger type (label)
|
||||
4. Save PR opened (if exists, this information comes later when agent has finished its task)
|
||||
- Save commit shas
|
||||
- Save author info
|
||||
5. Was PR merged or closed
|
||||
"""
|
||||
|
||||
conversation_id = github_view.conversation_id
|
||||
|
||||
if not conversation_id:
|
||||
@@ -385,7 +378,6 @@ class GitHubDataCollector:
|
||||
openhands_general_comment_count: int = 0,
|
||||
) -> dict:
|
||||
"""Build the final data structure for JSON storage"""
|
||||
|
||||
is_merged = pr_data['merged']
|
||||
merged_by = None
|
||||
merge_commit_sha = None
|
||||
@@ -419,8 +411,7 @@ class GitHubDataCollector:
|
||||
}
|
||||
|
||||
async def save_full_pr(self, openhands_pr: OpenhandsPR) -> None:
|
||||
"""
|
||||
Save PR information including metadata and commit details using GraphQL
|
||||
"""Save PR information including metadata and commit details using GraphQL
|
||||
|
||||
Saves:
|
||||
- Repo metadata (repo name, languages, contributors)
|
||||
@@ -606,17 +597,12 @@ class GitHubDataCollector:
|
||||
return None
|
||||
|
||||
def _is_pr_closed_or_merged(self, payload):
|
||||
"""
|
||||
Check if PR was closed (regardless of conversation URL)
|
||||
"""
|
||||
"""Check if PR was closed (regardless of conversation URL)"""
|
||||
action = payload.get('action', '')
|
||||
return action == 'closed' and 'pull_request' in payload
|
||||
|
||||
def _track_closed_or_merged_pr(self, payload):
|
||||
"""
|
||||
Track PR closed/merged event
|
||||
"""
|
||||
|
||||
"""Track PR closed/merged event"""
|
||||
repo_id = str(payload['repository']['id'])
|
||||
pr_number = payload['number']
|
||||
installation_id = str(payload['installation']['id'])
|
||||
|
||||
@@ -103,8 +103,7 @@ class SaaSGitHubService(GitHubService):
|
||||
}
|
||||
|
||||
async def get_repository_node_id(self, repo_id: str) -> str:
|
||||
"""
|
||||
Get the new GitHub GraphQL node ID for a repository using REST API.
|
||||
"""Get the new GitHub GraphQL node ID for a repository using REST API.
|
||||
|
||||
Args:
|
||||
repo_id: Numeric repository ID as string (e.g., "123456789")
|
||||
|
||||
@@ -39,7 +39,6 @@ def fetch_github_issue_context(
|
||||
Returns:
|
||||
A comprehensive string containing the issue/PR context
|
||||
"""
|
||||
|
||||
# Build context string
|
||||
context_parts = []
|
||||
|
||||
|
||||
@@ -55,7 +55,6 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
This function checks both the global environment variable kill switch AND
|
||||
the user's individual setting. Both must be true for the function to return true.
|
||||
"""
|
||||
|
||||
# If no user ID is provided, we can't check user settings
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
@@ -43,8 +43,7 @@ class GitlabManager(Manager):
|
||||
async def _user_has_write_access_to_repo(
|
||||
self, project_id: str, user_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the user has write access to the repository (can pull/push changes and open merge requests).
|
||||
"""Check if the user has write access to the repository (can pull/push changes and open merge requests).
|
||||
|
||||
Args:
|
||||
project_id: The ID of the GitLab project
|
||||
@@ -54,7 +53,6 @@ class GitlabManager(Manager):
|
||||
Returns:
|
||||
bool: True if the user has write access to the repository, False otherwise
|
||||
"""
|
||||
|
||||
keycloak_user_id = await self.token_manager.get_user_id_from_idp_user_id(
|
||||
user_id, ProviderType.GITLAB
|
||||
)
|
||||
@@ -117,8 +115,7 @@ class GitlabManager(Manager):
|
||||
return has_write_access
|
||||
|
||||
async def send_message(self, message: Message, gitlab_view: ResolverViewInterface):
|
||||
"""
|
||||
Send a message to GitLab based on the view type.
|
||||
"""Send a message to GitLab based on the view type.
|
||||
|
||||
Args:
|
||||
message: The message to send
|
||||
@@ -165,8 +162,7 @@ class GitlabManager(Manager):
|
||||
)
|
||||
|
||||
async def start_job(self, gitlab_view: GitlabViewType):
|
||||
"""
|
||||
Start a job for the GitLab view.
|
||||
"""Start a job for the GitLab view.
|
||||
|
||||
Args:
|
||||
gitlab_view: The GitLab view object containing issue/PR/comment info
|
||||
|
||||
@@ -81,8 +81,7 @@ class SaaSGitLabService(GitLabService):
|
||||
return gitlab_token
|
||||
|
||||
async def get_owned_groups(self) -> list[dict]:
|
||||
"""
|
||||
Get all groups for which the current user is the owner.
|
||||
"""Get all groups for which the current user is the owner.
|
||||
|
||||
Returns:
|
||||
list[dict]: A list of groups owned by the current user.
|
||||
@@ -98,8 +97,7 @@ class SaaSGitLabService(GitLabService):
|
||||
return []
|
||||
|
||||
async def add_owned_projects_and_groups_to_db(self, owned_personal_projects):
|
||||
"""
|
||||
Add owned projects and groups to the database for webhook tracking.
|
||||
"""Add owned projects and groups to the database for webhook tracking.
|
||||
|
||||
Args:
|
||||
owned_personal_projects: List of personal projects owned by the user
|
||||
@@ -147,8 +145,7 @@ class SaaSGitLabService(GitLabService):
|
||||
async def store_repository_data(
|
||||
self, users_personal_projects: list[dict], repositories: list[Repository]
|
||||
) -> None:
|
||||
"""
|
||||
Store repository data in the database.
|
||||
"""Store repository data in the database.
|
||||
This function combines the functionality of add_owned_projects_and_groups_to_db and store_repositories_in_db.
|
||||
|
||||
Args:
|
||||
@@ -171,8 +168,7 @@ class SaaSGitLabService(GitLabService):
|
||||
async def get_all_repositories(
|
||||
self, sort: str, app_mode: AppMode, store_in_background: bool = True
|
||||
) -> list[Repository]:
|
||||
"""
|
||||
Get repositories for the authenticated user, including information about the kind of project.
|
||||
"""Get repositories for the authenticated user, including information about the kind of project.
|
||||
Also collects repositories where the kind is "user" and the user is the owner.
|
||||
|
||||
Args:
|
||||
@@ -270,8 +266,7 @@ class SaaSGitLabService(GitLabService):
|
||||
async def check_resource_exists(
|
||||
self, resource_type: GitLabResourceType, resource_id: str
|
||||
) -> tuple[bool, WebhookStatus | None]:
|
||||
"""
|
||||
Check if resource exists and the user has access to it.
|
||||
"""Check if resource exists and the user has access to it.
|
||||
|
||||
Args:
|
||||
resource_type: The type of resource
|
||||
@@ -282,7 +277,6 @@ class SaaSGitLabService(GitLabService):
|
||||
- bool: True if the resource exists and the user has access to it, False otherwise
|
||||
- str: A reason message explaining the result
|
||||
"""
|
||||
|
||||
if resource_type == GitLabResourceType.GROUP:
|
||||
url = f'{self.BASE_URL}/groups/{resource_id}'
|
||||
else:
|
||||
@@ -301,8 +295,7 @@ class SaaSGitLabService(GitLabService):
|
||||
async def check_webhook_exists_on_resource(
|
||||
self, resource_type: GitLabResourceType, resource_id: str, webhook_url: str
|
||||
) -> tuple[bool, WebhookStatus | None]:
|
||||
"""
|
||||
Check if a webhook already exists for resource with a specific URL.
|
||||
"""Check if a webhook already exists for resource with a specific URL.
|
||||
|
||||
Args:
|
||||
resource_type: The type of resource
|
||||
@@ -314,7 +307,6 @@ class SaaSGitLabService(GitLabService):
|
||||
- bool: True if the webhook exists, False otherwise
|
||||
- str: A reason message explaining the result
|
||||
"""
|
||||
|
||||
# Construct the URL based on the resource type
|
||||
if resource_type == GitLabResourceType.GROUP:
|
||||
url = f'{self.BASE_URL}/groups/{resource_id}/hooks'
|
||||
@@ -343,8 +335,7 @@ class SaaSGitLabService(GitLabService):
|
||||
async def check_user_has_admin_access_to_resource(
|
||||
self, resource_type: GitLabResourceType, resource_id: str
|
||||
) -> tuple[bool, WebhookStatus | None]:
|
||||
"""
|
||||
Check if the user has admin access to resource (is either an owner or maintainer)
|
||||
"""Check if the user has admin access to resource (is either an owner or maintainer)
|
||||
|
||||
Args:
|
||||
resource_type: The type of resource
|
||||
@@ -355,7 +346,6 @@ class SaaSGitLabService(GitLabService):
|
||||
- bool: True if the user has admin access to the resource (owner or maintainer), False otherwise
|
||||
- str: A reason message explaining the result
|
||||
"""
|
||||
|
||||
# For groups, we need to check if the user is an owner or maintainer
|
||||
if resource_type == GitLabResourceType.GROUP:
|
||||
url = f'{self.BASE_URL}/groups/{resource_id}/members/all'
|
||||
@@ -412,8 +402,7 @@ class SaaSGitLabService(GitLabService):
|
||||
webhook_uuid: str,
|
||||
scopes: list[str],
|
||||
) -> tuple[str | None, WebhookStatus | None]:
|
||||
"""
|
||||
Install webhook for user's group or project
|
||||
"""Install webhook for user's group or project
|
||||
|
||||
Args:
|
||||
resource_type: The type of resource
|
||||
@@ -428,7 +417,6 @@ class SaaSGitLabService(GitLabService):
|
||||
- bool: True if installation was successful, False otherwise
|
||||
- str: A reason message explaining the result
|
||||
"""
|
||||
|
||||
description = 'Cloud OpenHands Resolver'
|
||||
|
||||
# Set up webhook parameters
|
||||
@@ -500,9 +488,7 @@ class SaaSGitLabService(GitLabService):
|
||||
async def reply_to_issue(
|
||||
self, project_id: str, issue_number: str, discussion_id: str | None, body: str
|
||||
):
|
||||
"""
|
||||
Either create new comment thread, or reply to comment thread (depending on discussion_id param)
|
||||
"""
|
||||
"""Either create new comment thread, or reply to comment thread (depending on discussion_id param)"""
|
||||
try:
|
||||
if discussion_id:
|
||||
url = f'{self.BASE_URL}/projects/{project_id}/issues/{issue_number}/discussions/{discussion_id}/notes'
|
||||
@@ -517,9 +503,7 @@ class SaaSGitLabService(GitLabService):
|
||||
async def reply_to_mr(
|
||||
self, project_id: str, merge_request_iid: str, discussion_id: str, body: str
|
||||
):
|
||||
"""
|
||||
Reply to comment thread on MR
|
||||
"""
|
||||
"""Reply to comment thread on MR"""
|
||||
try:
|
||||
url = f'{self.BASE_URL}/projects/{project_id}/merge_requests/{merge_request_iid}/discussions/{discussion_id}/notes'
|
||||
params = {'body': body}
|
||||
|
||||
@@ -48,7 +48,6 @@ class JiraManager(Manager):
|
||||
self, jira_user_id: str, workspace_id: int
|
||||
) -> tuple[JiraUser | None, UserAuth | None]:
|
||||
"""Authenticate Jira user and get their OpenHands user auth."""
|
||||
|
||||
# Find active Jira user by Keycloak user ID and workspace ID
|
||||
jira_user = await self.integration_store.get_active_user(
|
||||
jira_user_id, workspace_id
|
||||
@@ -206,7 +205,6 @@ class JiraManager(Manager):
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process incoming Jira webhook message."""
|
||||
|
||||
payload = message.message.get('payload', {})
|
||||
job_context = self.parse_webhook(payload)
|
||||
|
||||
@@ -299,10 +297,7 @@ class JiraManager(Manager):
|
||||
async def is_job_requested(
|
||||
self, message: Message, jira_view: JiraViewInterface
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a job is requested and handle repository selection.
|
||||
"""
|
||||
|
||||
"""Check if a job is requested and handle repository selection."""
|
||||
if isinstance(jira_view, JiraExistingConversationView):
|
||||
return True
|
||||
|
||||
|
||||
@@ -35,7 +35,6 @@ class JiraNewConversationView(JiraViewInterface):
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('jira_instructions.j2')
|
||||
instructions = instructions_template.render()
|
||||
|
||||
@@ -52,7 +51,6 @@ class JiraNewConversationView(JiraViewInterface):
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create a new Jira conversation"""
|
||||
|
||||
if not self.selected_repo:
|
||||
raise StartingConvoException('No repository selected for this conversation')
|
||||
|
||||
@@ -112,7 +110,6 @@ class JiraExistingConversationView(JiraViewInterface):
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_existing_conversation.j2')
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
@@ -125,7 +122,6 @@ class JiraExistingConversationView(JiraViewInterface):
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Update an existing Jira conversation"""
|
||||
|
||||
user_id = self.jira_user.keycloak_user_id
|
||||
|
||||
try:
|
||||
@@ -191,7 +187,6 @@ class JiraFactory:
|
||||
jira_workspace: JiraWorkspace,
|
||||
) -> JiraViewInterface:
|
||||
"""Create appropriate Jira view based on the message and user state"""
|
||||
|
||||
if not jira_user or not saas_user_auth or not jira_workspace:
|
||||
raise StartingConvoException('User not authenticated with Jira integration')
|
||||
|
||||
|
||||
@@ -48,7 +48,6 @@ class JiraDcManager(Manager):
|
||||
self, user_email: str, jira_dc_user_id: str, workspace_id: int
|
||||
) -> tuple[JiraDcUser | None, UserAuth | None]:
|
||||
"""Authenticate Jira DC user and get their OpenHands user auth."""
|
||||
|
||||
if not jira_dc_user_id or jira_dc_user_id == 'none':
|
||||
# Get Keycloak user ID from email
|
||||
keycloak_user_id = await self.token_manager.get_user_id_from_user_email(
|
||||
@@ -221,7 +220,6 @@ class JiraDcManager(Manager):
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process incoming Jira DC webhook message."""
|
||||
|
||||
payload = message.message.get('payload', {})
|
||||
job_context = self.parse_webhook(payload)
|
||||
|
||||
@@ -315,10 +313,7 @@ class JiraDcManager(Manager):
|
||||
async def is_job_requested(
|
||||
self, message: Message, jira_dc_view: JiraDcViewInterface
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a job is requested and handle repository selection.
|
||||
"""
|
||||
|
||||
"""Check if a job is requested and handle repository selection."""
|
||||
if isinstance(jira_dc_view, JiraDcExistingConversationView):
|
||||
return True
|
||||
|
||||
|
||||
@@ -38,7 +38,6 @@ class JiraDcNewConversationView(JiraDcViewInterface):
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('jira_dc_instructions.j2')
|
||||
instructions = instructions_template.render()
|
||||
|
||||
@@ -55,7 +54,6 @@ class JiraDcNewConversationView(JiraDcViewInterface):
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create a new Jira DC conversation"""
|
||||
|
||||
if not self.selected_repo:
|
||||
raise StartingConvoException('No repository selected for this conversation')
|
||||
|
||||
@@ -115,7 +113,6 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_dc_existing_conversation.j2')
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
@@ -128,7 +125,6 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Update an existing Jira conversation"""
|
||||
|
||||
user_id = self.jira_dc_user.keycloak_user_id
|
||||
|
||||
try:
|
||||
@@ -195,7 +191,6 @@ class JiraDcFactory:
|
||||
jira_dc_workspace: JiraDcWorkspace,
|
||||
) -> JiraDcViewInterface:
|
||||
"""Create appropriate Jira DC view based on the payload."""
|
||||
|
||||
if not jira_dc_user or not saas_user_auth or not jira_dc_workspace:
|
||||
raise StartingConvoException('User not authenticated with Jira integration')
|
||||
|
||||
|
||||
@@ -46,7 +46,6 @@ class LinearManager(Manager):
|
||||
self, linear_user_id: str, workspace_id: int
|
||||
) -> tuple[LinearUser | None, UserAuth | None]:
|
||||
"""Authenticate Linear user and get their OpenHands user auth."""
|
||||
|
||||
# Find active Linear user by Linear user ID and workspace ID
|
||||
linear_user = await self.integration_store.get_active_user(
|
||||
linear_user_id, workspace_id
|
||||
@@ -305,10 +304,7 @@ class LinearManager(Manager):
|
||||
async def is_job_requested(
|
||||
self, message: Message, linear_view: LinearViewInterface
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a job is requested and handle repository selection.
|
||||
"""
|
||||
|
||||
"""Check if a job is requested and handle repository selection."""
|
||||
if isinstance(linear_view, LinearExistingConversationView):
|
||||
return True
|
||||
|
||||
|
||||
@@ -35,7 +35,6 @@ class LinearNewConversationView(LinearViewInterface):
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('linear_instructions.j2')
|
||||
instructions = instructions_template.render()
|
||||
|
||||
@@ -52,7 +51,6 @@ class LinearNewConversationView(LinearViewInterface):
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create a new Linear conversation"""
|
||||
|
||||
if not self.selected_repo:
|
||||
raise StartingConvoException('No repository selected for this conversation')
|
||||
|
||||
@@ -112,7 +110,6 @@ class LinearExistingConversationView(LinearViewInterface):
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('linear_existing_conversation.j2')
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
@@ -125,7 +122,6 @@ class LinearExistingConversationView(LinearViewInterface):
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Update an existing Linear conversation"""
|
||||
|
||||
user_id = self.linear_user.keycloak_user_id
|
||||
|
||||
try:
|
||||
@@ -192,7 +188,6 @@ class LinearFactory:
|
||||
linear_workspace: LinearWorkspace,
|
||||
) -> LinearViewInterface:
|
||||
"""Create appropriate Linear view based on the message and user state"""
|
||||
|
||||
if not linear_user or not saas_user_auth or not linear_workspace:
|
||||
raise StartingConvoException(
|
||||
'User not authenticated with Linear integration'
|
||||
|
||||
@@ -8,22 +8,22 @@ class Manager(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def receive_message(self, message: Message):
|
||||
"Receive message from integration"
|
||||
"""Receive message from integration"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def send_message(self, message: Message):
|
||||
"Send message to integration from Openhands server"
|
||||
"""Send message to integration from Openhands server"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def is_job_requested(self, message: Message) -> bool:
|
||||
"Confirm that a job is being requested"
|
||||
"""Confirm that a job is being requested"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self):
|
||||
"Kick off a job with openhands agent"
|
||||
"""Kick off a job with openhands agent"""
|
||||
raise NotImplementedError
|
||||
|
||||
def create_outgoing_message(self, msg: str | dict, ephemeral: bool = False):
|
||||
|
||||
@@ -244,13 +244,11 @@ class SlackManager(Manager):
|
||||
async def is_job_requested(
|
||||
self, message: Message, slack_view: SlackViewInterface
|
||||
) -> bool:
|
||||
"""
|
||||
A job is always request we only receive webhooks for events associated with the slack bot
|
||||
"""A job is always request we only receive webhooks for events associated with the slack bot
|
||||
This method really just checks
|
||||
1. Is the user is authenticated
|
||||
2. Do we have the necessary information to start a job (either by inferring the selected repo, otherwise asking the user)
|
||||
"""
|
||||
|
||||
# Infer repo from user message is not needed; user selected repo from the form or is updating existing convo
|
||||
if isinstance(slack_view, SlackUpdateExistingConversationView):
|
||||
return True
|
||||
|
||||
@@ -24,17 +24,17 @@ class SlackViewInterface(SummaryExtractionTracker, ABC):
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"Instructions passed when conversation is first initialized"
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_or_update_conversation(self, jinja_env: Environment):
|
||||
"Create a new conversation"
|
||||
"""Create a new conversation"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_callback_id(self) -> str:
|
||||
"Unique callback id for subscribription made to EventStream for fetching agent summary"
|
||||
"""Unique callback id for subscribription made to EventStream for fetching agent summary"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -43,6 +43,4 @@ class SlackViewInterface(SummaryExtractionTracker, ABC):
|
||||
|
||||
|
||||
class StartingConvoException(Exception):
|
||||
"""
|
||||
Raised when trying to send message to a conversation that's is still starting up
|
||||
"""
|
||||
"""Raised when trying to send message to a conversation that's is still starting up"""
|
||||
|
||||
@@ -95,8 +95,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
return ''
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"Instructions passed when conversation is first initialized"
|
||||
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
|
||||
messages = []
|
||||
@@ -179,9 +178,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
await slack_conversation_store.create_slack_conversation(slack_conversation)
|
||||
|
||||
async def create_or_update_conversation(self, jinja: Environment) -> str:
|
||||
"""
|
||||
Only creates a new conversation
|
||||
"""
|
||||
"""Only creates a new conversation"""
|
||||
self._verify_necessary_values_are_set()
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
@@ -246,9 +243,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
return user_message, ''
|
||||
|
||||
async def create_or_update_conversation(self, jinja: Environment) -> str:
|
||||
"""
|
||||
Send new user message to converation
|
||||
"""
|
||||
"""Send new user message to converation"""
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
saas_user_auth: UserAuth = self.saas_user_auth
|
||||
user_id = user_info.keycloak_user_id
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Utilities for loading and managing pre-trained classifiers.
|
||||
"""Utilities for loading and managing pre-trained classifiers.
|
||||
|
||||
Assumes that classifiers are stored adjacent to this file in the `solvability/data` directory, using a simple
|
||||
`name + .json` pattern.
|
||||
@@ -11,8 +10,7 @@ from integrations.solvability.models.classifier import SolvabilityClassifier
|
||||
|
||||
|
||||
def load_classifier(name: str) -> SolvabilityClassifier:
|
||||
"""
|
||||
Load a classifier by name.
|
||||
"""Load a classifier by name.
|
||||
|
||||
Args:
|
||||
name (str): The name of the classifier to load.
|
||||
@@ -31,8 +29,7 @@ def load_classifier(name: str) -> SolvabilityClassifier:
|
||||
|
||||
|
||||
def available_classifiers() -> list[str]:
|
||||
"""
|
||||
List all available classifiers in the data directory.
|
||||
"""List all available classifiers in the data directory.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of classifier names (without the .json extension).
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Solvability Models Package
|
||||
"""Solvability Models Package
|
||||
|
||||
This package contains the core machine learning models and components for predicting
|
||||
the solvability of GitHub issues and similar technical problems.
|
||||
|
||||
@@ -26,8 +26,7 @@ from openhands.core.config import LLMConfig
|
||||
|
||||
|
||||
class SolvabilityClassifier(BaseModel):
|
||||
"""
|
||||
Machine learning pipeline for predicting the solvability of GitHub issues and similar problems.
|
||||
"""Machine learning pipeline for predicting the solvability of GitHub issues and similar problems.
|
||||
|
||||
This classifier combines LLM-based feature extraction with traditional ML classification:
|
||||
1. Uses a Featurizer to extract semantic boolean features from issue descriptions via LLM calls
|
||||
@@ -87,9 +86,7 @@ class SolvabilityClassifier(BaseModel):
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_random_state(self) -> SolvabilityClassifier:
|
||||
"""
|
||||
Validate the random state configuration between this object and the classifier.
|
||||
"""
|
||||
"""Validate the random state configuration between this object and the classifier."""
|
||||
# If both random states are set, they definitely need to agree.
|
||||
if self.random_state is not None and self.classifier.random_state is not None:
|
||||
if self.random_state != self.classifier.random_state:
|
||||
@@ -104,9 +101,7 @@ class SolvabilityClassifier(BaseModel):
|
||||
|
||||
@property
|
||||
def features_(self) -> pd.DataFrame:
|
||||
"""
|
||||
Get the features used by the classifier for the most recent inputs.
|
||||
"""
|
||||
"""Get the features used by the classifier for the most recent inputs."""
|
||||
if 'features_' not in self._classifier_attrs:
|
||||
raise ValueError(
|
||||
'SolvabilityClassifier.transform() has not yet been called.'
|
||||
@@ -115,9 +110,7 @@ class SolvabilityClassifier(BaseModel):
|
||||
|
||||
@property
|
||||
def cost_(self) -> pd.DataFrame:
|
||||
"""
|
||||
Get the cost of the classifier for the most recent inputs.
|
||||
"""
|
||||
"""Get the cost of the classifier for the most recent inputs."""
|
||||
if 'cost_' not in self._classifier_attrs:
|
||||
raise ValueError(
|
||||
'SolvabilityClassifier.transform() has not yet been called.'
|
||||
@@ -126,9 +119,7 @@ class SolvabilityClassifier(BaseModel):
|
||||
|
||||
@property
|
||||
def feature_importances_(self) -> np.ndarray:
|
||||
"""
|
||||
Get the feature importances for the most recent inputs.
|
||||
"""
|
||||
"""Get the feature importances for the most recent inputs."""
|
||||
if 'feature_importances_' not in self._classifier_attrs:
|
||||
raise ValueError(
|
||||
'No SolvabilityClassifier methods that produce feature importances (.fit(), .predict_proba(), and '
|
||||
@@ -138,9 +129,7 @@ class SolvabilityClassifier(BaseModel):
|
||||
|
||||
@property
|
||||
def is_fitted(self) -> bool:
|
||||
"""
|
||||
Check if the classifier is fitted.
|
||||
"""
|
||||
"""Check if the classifier is fitted."""
|
||||
try:
|
||||
check_is_fitted(self.classifier)
|
||||
return True
|
||||
@@ -148,8 +137,7 @@ class SolvabilityClassifier(BaseModel):
|
||||
return False
|
||||
|
||||
def transform(self, issues: pd.Series, llm_config: LLMConfig) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input issues using the featurizer to extract features.
|
||||
"""Transform the input issues using the featurizer to extract features.
|
||||
|
||||
This method orchestrates the feature extraction pipeline:
|
||||
1. Uses the featurizer to generate embeddings for all issues
|
||||
@@ -183,8 +171,7 @@ class SolvabilityClassifier(BaseModel):
|
||||
def fit(
|
||||
self, issues: pd.Series, labels: pd.Series, llm_config: LLMConfig
|
||||
) -> SolvabilityClassifier:
|
||||
"""
|
||||
Fit the classifier to the input issues and labels.
|
||||
"""Fit the classifier to the input issues and labels.
|
||||
|
||||
Args:
|
||||
issues: A pandas Series containing the issue descriptions.
|
||||
@@ -208,8 +195,7 @@ class SolvabilityClassifier(BaseModel):
|
||||
return self
|
||||
|
||||
def predict_proba(self, issues: pd.Series, llm_config: LLMConfig) -> np.ndarray:
|
||||
"""
|
||||
Predict the solvability probabilities for the input issues.
|
||||
"""Predict the solvability probabilities for the input issues.
|
||||
|
||||
Returns class probabilities where the second column represents the probability
|
||||
of the issue being solvable (positive class).
|
||||
@@ -243,8 +229,7 @@ class SolvabilityClassifier(BaseModel):
|
||||
return scores # type: ignore[no-any-return]
|
||||
|
||||
def predict(self, issues: pd.Series, llm_config: LLMConfig) -> np.ndarray:
|
||||
"""
|
||||
Predict the solvability of the input issues by returning binary labels.
|
||||
"""Predict the solvability of the input issues by returning binary labels.
|
||||
|
||||
Uses a 0.5 probability threshold to convert probabilities to binary predictions.
|
||||
|
||||
@@ -266,8 +251,7 @@ class SolvabilityClassifier(BaseModel):
|
||||
scores: np.ndarray,
|
||||
labels: np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Calculate feature importance scores using the configured strategy.
|
||||
"""Calculate feature importance scores using the configured strategy.
|
||||
|
||||
Different strategies provide different interpretations:
|
||||
- SHAP: Shapley values indicating contribution to individual predictions
|
||||
@@ -313,8 +297,7 @@ class SolvabilityClassifier(BaseModel):
|
||||
)
|
||||
|
||||
def add_features(self, features: list[Feature]) -> SolvabilityClassifier:
|
||||
"""
|
||||
Add new features to the classifier's featurizer.
|
||||
"""Add new features to the classifier's featurizer.
|
||||
|
||||
Note: Adding features after training requires retraining the classifier
|
||||
since the feature space will have changed.
|
||||
@@ -331,8 +314,7 @@ class SolvabilityClassifier(BaseModel):
|
||||
return self
|
||||
|
||||
def forget_features(self, features: list[Feature]) -> SolvabilityClassifier:
|
||||
"""
|
||||
Remove features from the classifier's featurizer.
|
||||
"""Remove features from the classifier's featurizer.
|
||||
|
||||
Note: Removing features after training requires retraining the classifier
|
||||
since the feature space will have changed.
|
||||
@@ -354,17 +336,13 @@ class SolvabilityClassifier(BaseModel):
|
||||
@field_serializer('classifier')
|
||||
@staticmethod
|
||||
def _rfc_to_json(rfc: RandomForestClassifier) -> str:
|
||||
"""
|
||||
Convert a RandomForestClassifier to a JSON-compatible value (a string).
|
||||
"""
|
||||
"""Convert a RandomForestClassifier to a JSON-compatible value (a string)."""
|
||||
return base64.b64encode(pickle.dumps(rfc)).decode('utf-8')
|
||||
|
||||
@field_validator('classifier', mode='before')
|
||||
@staticmethod
|
||||
def _json_to_rfc(value: str | RandomForestClassifier) -> RandomForestClassifier:
|
||||
"""
|
||||
Convert a JSON-compatible value (a string) back to a RandomForestClassifier.
|
||||
"""
|
||||
"""Convert a JSON-compatible value (a string) back to a RandomForestClassifier."""
|
||||
if isinstance(value, RandomForestClassifier):
|
||||
return value
|
||||
|
||||
@@ -383,8 +361,7 @@ class SolvabilityClassifier(BaseModel):
|
||||
def solvability_report(
|
||||
self, issue: str, llm_config: LLMConfig, **kwargs: Any
|
||||
) -> SolvabilityReport:
|
||||
"""
|
||||
Generate a solvability report for the given issue.
|
||||
"""Generate a solvability report for the given issue.
|
||||
|
||||
Args:
|
||||
issue: The issue description for which to generate the report.
|
||||
@@ -427,7 +404,5 @@ class SolvabilityClassifier(BaseModel):
|
||||
def __call__(
|
||||
self, issue: str, llm_config: LLMConfig, **kwargs: Any
|
||||
) -> SolvabilityReport:
|
||||
"""
|
||||
Generate a solvability report for the given issue.
|
||||
"""
|
||||
"""Generate a solvability report for the given issue."""
|
||||
return self.solvability_report(issue, llm_config=llm_config, **kwargs)
|
||||
|
||||
@@ -10,8 +10,7 @@ from openhands.llm.llm import LLM
|
||||
|
||||
|
||||
class Feature(BaseModel):
|
||||
"""
|
||||
Represents a single boolean feature that can be extracted from issue descriptions.
|
||||
"""Represents a single boolean feature that can be extracted from issue descriptions.
|
||||
|
||||
Features are semantic properties of issues (e.g., "has_code_example", "requires_debugging")
|
||||
that are evaluated by LLMs and used as input to the solvability classifier.
|
||||
@@ -25,8 +24,7 @@ class Feature(BaseModel):
|
||||
|
||||
@property
|
||||
def to_tool_description_field(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert this feature to a JSON schema field for LLM tool calling.
|
||||
"""Convert this feature to a JSON schema field for LLM tool calling.
|
||||
|
||||
Returns:
|
||||
dict: JSON schema field definition for this feature.
|
||||
@@ -38,8 +36,7 @@ class Feature(BaseModel):
|
||||
|
||||
|
||||
class EmbeddingDimension(BaseModel):
|
||||
"""
|
||||
Represents a single dimension (feature evaluation) within a feature embedding sample.
|
||||
"""Represents a single dimension (feature evaluation) within a feature embedding sample.
|
||||
|
||||
Each dimension corresponds to one feature being evaluated as true/false for a given issue.
|
||||
"""
|
||||
@@ -60,8 +57,7 @@ Maps feature identifiers to their boolean evaluations.
|
||||
|
||||
|
||||
class FeatureEmbedding(BaseModel):
|
||||
"""
|
||||
Represents the complete feature embedding for a single issue, including multiple samples
|
||||
"""Represents the complete feature embedding for a single issue, including multiple samples
|
||||
and associated metadata about the LLM calls used to generate it.
|
||||
|
||||
Multiple samples are collected to account for LLM variability and provide more robust
|
||||
@@ -82,8 +78,7 @@ class FeatureEmbedding(BaseModel):
|
||||
|
||||
@property
|
||||
def dimensions(self) -> list[str]:
|
||||
"""
|
||||
Get all unique feature identifiers present across all samples.
|
||||
"""Get all unique feature identifiers present across all samples.
|
||||
|
||||
Returns:
|
||||
list[str]: List of feature identifiers that appear in at least one sample.
|
||||
@@ -94,8 +89,7 @@ class FeatureEmbedding(BaseModel):
|
||||
return list(dims)
|
||||
|
||||
def coefficient(self, dimension: str) -> float | None:
|
||||
"""
|
||||
Calculate the average coefficient (0-1) for a specific feature dimension.
|
||||
"""Calculate the average coefficient (0-1) for a specific feature dimension.
|
||||
|
||||
This computes the proportion of samples where the feature was evaluated as True,
|
||||
providing a continuous feature value for the classifier.
|
||||
@@ -117,8 +111,7 @@ class FeatureEmbedding(BaseModel):
|
||||
return None
|
||||
|
||||
def to_row(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert the embedding to a flat dictionary suitable for DataFrame construction.
|
||||
"""Convert the embedding to a flat dictionary suitable for DataFrame construction.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Dictionary with metadata fields and feature coefficients.
|
||||
@@ -131,8 +124,7 @@ class FeatureEmbedding(BaseModel):
|
||||
}
|
||||
|
||||
def sample_entropy(self) -> dict[str, float]:
|
||||
"""
|
||||
Calculate the Shannon entropy of feature evaluations across samples.
|
||||
"""Calculate the Shannon entropy of feature evaluations across samples.
|
||||
|
||||
Higher entropy indicates more variability in LLM responses for a feature,
|
||||
which may suggest ambiguity in the feature definition or issue description.
|
||||
@@ -162,8 +154,7 @@ class FeatureEmbedding(BaseModel):
|
||||
|
||||
|
||||
class Featurizer(BaseModel):
|
||||
"""
|
||||
Orchestrates LLM-based feature extraction from issue descriptions.
|
||||
"""Orchestrates LLM-based feature extraction from issue descriptions.
|
||||
|
||||
The Featurizer uses structured LLM tool calling to evaluate boolean features
|
||||
for issue descriptions. It handles prompt construction, tool schema generation,
|
||||
@@ -180,8 +171,7 @@ class Featurizer(BaseModel):
|
||||
"""List of features to extract from each issue description."""
|
||||
|
||||
def system_message(self) -> dict[str, Any]:
|
||||
"""
|
||||
Construct the system message for LLM conversations.
|
||||
"""Construct the system message for LLM conversations.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: System message dictionary for LLM API calls.
|
||||
@@ -194,8 +184,7 @@ class Featurizer(BaseModel):
|
||||
def user_message(
|
||||
self, issue_description: str, set_cache: bool = True
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Construct the user message containing the issue description.
|
||||
"""Construct the user message containing the issue description.
|
||||
|
||||
Args:
|
||||
issue_description: The description of the issue to analyze.
|
||||
@@ -215,8 +204,7 @@ class Featurizer(BaseModel):
|
||||
|
||||
@property
|
||||
def tool_choice(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get the tool choice configuration for forcing LLM to use the featurizer tool.
|
||||
"""Get the tool choice configuration for forcing LLM to use the featurizer tool.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Tool choice configuration for LLM API calls.
|
||||
@@ -228,8 +216,7 @@ class Featurizer(BaseModel):
|
||||
|
||||
@property
|
||||
def tool_description(self) -> dict[str, Any]:
|
||||
"""
|
||||
Generate the tool schema for the featurizer function.
|
||||
"""Generate the tool schema for the featurizer function.
|
||||
|
||||
Creates a JSON schema that describes the featurizer tool with all configured
|
||||
features as boolean parameters.
|
||||
@@ -259,8 +246,7 @@ class Featurizer(BaseModel):
|
||||
temperature: float = 1.0,
|
||||
samples: int = 10,
|
||||
) -> FeatureEmbedding:
|
||||
"""
|
||||
Generate a feature embedding for a single issue description.
|
||||
"""Generate a feature embedding for a single issue description.
|
||||
|
||||
Makes multiple LLM calls to collect samples and reduce variance in feature evaluations.
|
||||
Each call uses tool calling to extract structured boolean feature values.
|
||||
@@ -322,8 +308,7 @@ class Featurizer(BaseModel):
|
||||
temperature: float = 1.0,
|
||||
samples: int = 10,
|
||||
) -> list[FeatureEmbedding]:
|
||||
"""
|
||||
Generate embeddings for a batch of issue descriptions using concurrent processing.
|
||||
"""Generate embeddings for a batch of issue descriptions using concurrent processing.
|
||||
|
||||
Processes multiple issues in parallel to improve throughput while maintaining
|
||||
result ordering.
|
||||
@@ -359,8 +344,7 @@ class Featurizer(BaseModel):
|
||||
return results
|
||||
|
||||
def feature_identifiers(self) -> list[str]:
|
||||
"""
|
||||
Get the identifiers of all configured features.
|
||||
"""Get the identifiers of all configured features.
|
||||
|
||||
Returns:
|
||||
list[str]: List of feature identifiers in the order they were defined.
|
||||
|
||||
@@ -2,8 +2,7 @@ from enum import Enum
|
||||
|
||||
|
||||
class ImportanceStrategy(str, Enum):
|
||||
"""
|
||||
Strategy to use for calculating feature importances, which are used to estimate the predictive power of each feature
|
||||
"""Strategy to use for calculating feature importances, which are used to estimate the predictive power of each feature
|
||||
in training loops and explanations.
|
||||
"""
|
||||
|
||||
|
||||
@@ -6,8 +6,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SolvabilityReport(BaseModel):
|
||||
"""
|
||||
Comprehensive report containing solvability predictions and analysis for a single issue.
|
||||
"""Comprehensive report containing solvability predictions and analysis for a single issue.
|
||||
|
||||
This report includes the solvability score, extracted feature values, feature importance analysis,
|
||||
cost metrics (tokens and latency), and metadata about the prediction process. It serves as the
|
||||
|
||||
@@ -39,13 +39,13 @@ class ResolverViewInterface(SummaryExtractionTracker):
|
||||
raw_payload: dict
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"Instructions passed when conversation is first initialized"
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def create_new_conversation(self, jinja_env: Environment, token: str):
|
||||
"Create a new conversation"
|
||||
"""Create a new conversation"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_callback_id(self) -> str:
|
||||
"Unique callback id for subscribription made to EventStream for fetching agent summary"
|
||||
"""Unique callback id for subscribription made to EventStream for fetching agent summary"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -215,9 +215,7 @@ def get_last_user_msg(event_store: EventStoreABC) -> list[MessageAction]:
|
||||
def extract_summary_from_event_store(
|
||||
event_store: EventStoreABC, conversation_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Get agent summary or alternative message depending on current AgentState
|
||||
"""
|
||||
"""Get agent summary or alternative message depending on current AgentState"""
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
summary_instruction = get_summary_instruction()
|
||||
|
||||
@@ -293,10 +291,7 @@ async def get_last_user_msg_from_conversation_manager(
|
||||
async def extract_summary_from_conversation_manager(
|
||||
conversation_manager: ConversationManager, conversation_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Get agent summary or alternative message depending on current AgentState
|
||||
"""
|
||||
|
||||
"""Get agent summary or alternative message depending on current AgentState"""
|
||||
event_store = await get_event_store_from_conversation_manager(
|
||||
conversation_manager, conversation_id
|
||||
)
|
||||
@@ -305,8 +300,7 @@ async def extract_summary_from_conversation_manager(
|
||||
|
||||
|
||||
def append_conversation_footer(message: str, conversation_id: str) -> str:
|
||||
"""
|
||||
Append a small footer with the conversation URL to a message.
|
||||
"""Append a small footer with the conversation URL to a message.
|
||||
|
||||
Args:
|
||||
message: The original message content
|
||||
@@ -321,14 +315,12 @@ def append_conversation_footer(message: str, conversation_id: str) -> str:
|
||||
|
||||
|
||||
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
|
||||
"""
|
||||
Store repositories in DB and create user-repository mappings
|
||||
"""Store repositories in DB and create user-repository mappings
|
||||
|
||||
Args:
|
||||
repos: List of Repository objects to store
|
||||
user_id: User ID associated with these repositories
|
||||
"""
|
||||
|
||||
# Convert Repository objects to StoredRepository objects
|
||||
# Convert Repository objects to UserRepositoryMap objects
|
||||
stored_repos = []
|
||||
@@ -366,9 +358,9 @@ async def store_repositories_in_db(repos: list[Repository], user_id: str) -> Non
|
||||
|
||||
|
||||
def infer_repo_from_message(user_msg: str) -> list[str]:
|
||||
"""
|
||||
Extract all repository names in the format 'owner/repo' from various Git provider URLs
|
||||
"""Extract all repository names in the format 'owner/repo' from various Git provider URLs
|
||||
and direct mentions in text. Supports GitHub, GitLab, and BitBucket.
|
||||
|
||||
Args:
|
||||
user_msg: Input message that may contain repository references
|
||||
Returns:
|
||||
@@ -451,10 +443,10 @@ def filter_potential_repos_by_user_msg(
|
||||
|
||||
|
||||
def markdown_to_jira_markup(markdown_text: str) -> str:
|
||||
"""
|
||||
Convert markdown text to Jira Wiki Markup format.
|
||||
"""Convert markdown text to Jira Wiki Markup format.
|
||||
This function handles common markdown elements and converts them to their
|
||||
Jira Wiki Markup equivalents. It's designed to be exception-safe.
|
||||
|
||||
Args:
|
||||
markdown_text: The markdown text to convert
|
||||
Returns:
|
||||
|
||||
@@ -22,8 +22,7 @@ depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""
|
||||
Create maintenance tasks for all users whose user_version is less than
|
||||
"""Create maintenance tasks for all users whose user_version is less than
|
||||
the current version.
|
||||
|
||||
This replaces the functionality of the removed admin maintenance endpoint.
|
||||
@@ -89,8 +88,7 @@ def upgrade():
|
||||
|
||||
|
||||
def downgrade():
|
||||
"""
|
||||
No downgrade operation needed as we're just creating tasks.
|
||||
"""No downgrade operation needed as we're just creating tasks.
|
||||
The tasks themselves will be processed and completed.
|
||||
|
||||
If needed, we could delete tasks with this processor type, but that's not necessary
|
||||
|
||||
@@ -19,7 +19,10 @@ from server.auth.constants import ( # noqa: E402
|
||||
from server.constants import PERMITTED_CORS_ORIGINS # noqa: E402
|
||||
from server.logger import logger # noqa: E402
|
||||
from server.metrics import metrics_app # noqa: E402
|
||||
from server.middleware import SetAuthCookieMiddleware # noqa: E402
|
||||
from server.middleware import ( # noqa: E402
|
||||
LLMSettingsMiddleware,
|
||||
SetAuthCookieMiddleware,
|
||||
)
|
||||
from server.rate_limit import setup_rate_limit_handler # noqa: E402
|
||||
from server.routes.api_keys import api_router as api_keys_router # noqa: E402
|
||||
from server.routes.auth import api_router, oauth_router # noqa: E402
|
||||
@@ -105,6 +108,7 @@ base_app.add_middleware(
|
||||
allow_headers=['*'],
|
||||
)
|
||||
base_app.add_middleware(CacheControlMiddleware)
|
||||
base_app.middleware('http')(LLMSettingsMiddleware())
|
||||
base_app.middleware('http')(SetAuthCookieMiddleware())
|
||||
|
||||
base_app.mount('/', SPAStaticFiles(directory=directory, html=True), name='dist')
|
||||
|
||||
@@ -31,6 +31,7 @@ class GoogleSheetsClient:
|
||||
self, spreadsheet_id: str, range_name: str
|
||||
) -> Optional[List[str]]:
|
||||
"""Get usernames from cache if available and not expired.
|
||||
|
||||
Args:
|
||||
spreadsheet_id: The ID of the Google Sheet
|
||||
range_name: The A1 notation of the range to fetch
|
||||
@@ -56,6 +57,7 @@ class GoogleSheetsClient:
|
||||
self, spreadsheet_id: str, range_name: str, usernames: List[str]
|
||||
) -> None:
|
||||
"""Update cache with new usernames and current timestamp.
|
||||
|
||||
Args:
|
||||
spreadsheet_id: The ID of the Google Sheet
|
||||
range_name: The A1 notation of the range to fetch
|
||||
@@ -67,6 +69,7 @@ class GoogleSheetsClient:
|
||||
def get_usernames(self, spreadsheet_id: str, range_name: str = 'A:A') -> List[str]:
|
||||
"""Get list of usernames from specified Google Sheet.
|
||||
Uses cached data if available and less than 15 seconds old.
|
||||
|
||||
Args:
|
||||
spreadsheet_id: The ID of the Google Sheet
|
||||
range_name: The A1 notation of the range to fetch
|
||||
|
||||
@@ -483,8 +483,7 @@ class ClusteredConversationManager(StandaloneConversationManager):
|
||||
await pipe.execute()
|
||||
|
||||
async def _disconnect_from_stopped(self):
|
||||
"""
|
||||
Handle connections to conversations that have stopped unexpectedly.
|
||||
"""Handle connections to conversations that have stopped unexpectedly.
|
||||
|
||||
This method detects when a local connection is pointing to a conversation
|
||||
that was running on another server that has crashed or been terminated
|
||||
|
||||
@@ -70,8 +70,7 @@ PERMITTED_CORS_ORIGINS = [
|
||||
|
||||
|
||||
def build_litellm_proxy_model_path(model_name: str) -> str:
|
||||
"""
|
||||
Build the LiteLLM proxy model path based on environment and model name.
|
||||
"""Build the LiteLLM proxy model path based on environment and model name.
|
||||
|
||||
This utility constructs the full model path for LiteLLM proxy based on:
|
||||
- Environment type (staging vs prod)
|
||||
@@ -83,7 +82,6 @@ def build_litellm_proxy_model_path(model_name: str) -> str:
|
||||
Returns:
|
||||
The full LiteLLM proxy model path (e.g., 'litellm_proxy/prod/claude-3-7-sonnet-20250219')
|
||||
"""
|
||||
|
||||
if 'prod' in model_name or 'litellm' in model_name or 'proxy' in model_name:
|
||||
raise ValueError("Only include model name, don't include prefix")
|
||||
|
||||
@@ -96,8 +94,7 @@ def build_litellm_proxy_model_path(model_name: str) -> str:
|
||||
|
||||
|
||||
def get_default_litellm_model():
|
||||
"""
|
||||
Construct proxy for litellm model based on user settings and environment type (staging vs prod)
|
||||
"""Construct proxy for litellm model based on user settings and environment type (staging vs prod)
|
||||
if not set explicitly
|
||||
"""
|
||||
if LITELLM_DEFAULT_MODEL:
|
||||
|
||||
@@ -25,8 +25,7 @@ from openhands.server.shared import conversation_manager
|
||||
|
||||
|
||||
class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
"""
|
||||
Processor for sending conversation summaries to GitHub.
|
||||
"""Processor for sending conversation summaries to GitHub.
|
||||
|
||||
This processor is used to send summaries of conversations to GitHub issues/PRs
|
||||
when agent state changes occur.
|
||||
@@ -36,8 +35,7 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
send_summary_instruction: bool = True
|
||||
|
||||
async def _send_message_to_github(self, message: str) -> None:
|
||||
"""
|
||||
Send a message to GitHub.
|
||||
"""Send a message to GitHub.
|
||||
|
||||
Args:
|
||||
message: The message content to send to GitHub
|
||||
@@ -68,8 +66,7 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
callback: ConversationCallback,
|
||||
observation: AgentStateChangedObservation,
|
||||
) -> None:
|
||||
"""
|
||||
Process a conversation event by sending a summary to GitHub.
|
||||
"""Process a conversation event by sending a summary to GitHub.
|
||||
|
||||
Args:
|
||||
callback: The conversation callback
|
||||
|
||||
@@ -28,8 +28,7 @@ gitlab_manager = GitlabManager(token_manager)
|
||||
|
||||
|
||||
class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
"""
|
||||
Processor for sending conversation summaries to GitLab.
|
||||
"""Processor for sending conversation summaries to GitLab.
|
||||
|
||||
This processor is used to send summaries of conversations to GitLab
|
||||
when agent state changes occur.
|
||||
@@ -39,8 +38,7 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
send_summary_instruction: bool = True
|
||||
|
||||
async def _send_message_to_gitlab(self, message: str) -> None:
|
||||
"""
|
||||
Send a message to GitLab.
|
||||
"""Send a message to GitLab.
|
||||
|
||||
Args:
|
||||
message: The message content to send to GitLab
|
||||
@@ -67,8 +65,7 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
callback: ConversationCallback,
|
||||
observation: AgentStateChangedObservation,
|
||||
) -> None:
|
||||
"""
|
||||
Process a conversation event by sending a summary to GitLab.
|
||||
"""Process a conversation event by sending a summary to GitLab.
|
||||
|
||||
Args:
|
||||
callback: The conversation callback
|
||||
|
||||
@@ -26,8 +26,7 @@ integration_store = jira_manager.integration_store
|
||||
|
||||
|
||||
class JiraCallbackProcessor(ConversationCallbackProcessor):
|
||||
"""
|
||||
Processor for sending conversation summaries to Jira.
|
||||
"""Processor for sending conversation summaries to Jira.
|
||||
|
||||
This processor is used to send summaries of conversations to Jira issues
|
||||
when agent state changes occur.
|
||||
@@ -37,8 +36,7 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace_name: str
|
||||
|
||||
async def _send_comment_to_jira(self, message: str) -> None:
|
||||
"""
|
||||
Send a comment to Jira issue.
|
||||
"""Send a comment to Jira issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Jira
|
||||
@@ -79,8 +77,7 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
|
||||
callback: ConversationCallback,
|
||||
observation: AgentStateChangedObservation,
|
||||
) -> None:
|
||||
"""
|
||||
Process a conversation event by sending a summary to Jira.
|
||||
"""Process a conversation event by sending a summary to Jira.
|
||||
|
||||
Args:
|
||||
callback: The conversation callback
|
||||
|
||||
@@ -25,8 +25,7 @@ jira_dc_manager = JiraDcManager(token_manager)
|
||||
|
||||
|
||||
class JiraDcCallbackProcessor(ConversationCallbackProcessor):
|
||||
"""
|
||||
Processor for sending conversation summaries to Jira DC.
|
||||
"""Processor for sending conversation summaries to Jira DC.
|
||||
|
||||
This processor is used to send summaries of conversations to Jira DC issues
|
||||
when agent state changes occur.
|
||||
@@ -37,8 +36,7 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
|
||||
base_api_url: str
|
||||
|
||||
async def _send_comment_to_jira_dc(self, message: str) -> None:
|
||||
"""
|
||||
Send a comment to Jira DC issue.
|
||||
"""Send a comment to Jira DC issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Jira DC
|
||||
@@ -80,8 +78,7 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
|
||||
callback: ConversationCallback,
|
||||
observation: AgentStateChangedObservation,
|
||||
) -> None:
|
||||
"""
|
||||
Process a conversation event by sending a summary to Jira DC.
|
||||
"""Process a conversation event by sending a summary to Jira DC.
|
||||
|
||||
Args:
|
||||
callback: The conversation callback
|
||||
|
||||
@@ -24,8 +24,7 @@ linear_manager = LinearManager(token_manager)
|
||||
|
||||
|
||||
class LinearCallbackProcessor(ConversationCallbackProcessor):
|
||||
"""
|
||||
Processor for sending conversation summaries to Linear.
|
||||
"""Processor for sending conversation summaries to Linear.
|
||||
|
||||
This processor is used to send summaries of conversations to Linear issues
|
||||
when agent state changes occur.
|
||||
@@ -36,8 +35,7 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace_name: str
|
||||
|
||||
async def _send_comment_to_linear(self, message: str) -> None:
|
||||
"""
|
||||
Send a comment to Linear issue.
|
||||
"""Send a comment to Linear issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Linear
|
||||
@@ -79,8 +77,7 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
|
||||
callback: ConversationCallback,
|
||||
observation: AgentStateChangedObservation,
|
||||
) -> None:
|
||||
"""
|
||||
Process a conversation event by sending a summary to Linear.
|
||||
"""Process a conversation event by sending a summary to Linear.
|
||||
|
||||
Args:
|
||||
callback: The conversation callback
|
||||
|
||||
@@ -26,8 +26,7 @@ slack_manager = SlackManager(token_manager)
|
||||
|
||||
|
||||
class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
"""
|
||||
Processor for sending conversation summaries to Slack.
|
||||
"""Processor for sending conversation summaries to Slack.
|
||||
|
||||
This processor is used to send summaries of conversations to Slack channels
|
||||
when agent state changes occur.
|
||||
@@ -41,8 +40,7 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
last_user_msg_id: int | None = None
|
||||
|
||||
async def _send_message_to_slack(self, message: str) -> None:
|
||||
"""
|
||||
Send a message to Slack using the conversation_manager's send_to_event_stream method.
|
||||
"""Send a message to Slack using the conversation_manager's send_to_event_stream method.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Slack
|
||||
@@ -83,8 +81,7 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
callback: ConversationCallback,
|
||||
observation: AgentStateChangedObservation,
|
||||
) -> None:
|
||||
"""
|
||||
Process a conversation event by sending a summary to Slack.
|
||||
"""Process a conversation event by sending a summary to Slack.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation to process
|
||||
|
||||
@@ -33,8 +33,7 @@ class LegacyCacheEntry:
|
||||
|
||||
@dataclass
|
||||
class LegacyConversationManager(ConversationManager):
|
||||
"""
|
||||
Conversation manager for use while migrating - since existing conversations are not nested!
|
||||
"""Conversation manager for use while migrating - since existing conversations are not nested!
|
||||
Separate class from SaasNestedConversationManager so it can be easliy removed in a few weeks.
|
||||
(As of 2025-07-23)
|
||||
"""
|
||||
@@ -270,8 +269,7 @@ class LegacyConversationManager(ConversationManager):
|
||||
del self._legacy_cache[key]
|
||||
|
||||
async def should_start_in_legacy_mode(self, conversation_id: str) -> bool:
|
||||
"""
|
||||
Check if a conversation should run in legacy mode by directly checking the runtime.
|
||||
"""Check if a conversation should run in legacy mode by directly checking the runtime.
|
||||
The /list method does not include stopped conversations even though the PVC for these
|
||||
may not yet have been deleted, so we need to check /sessions/{session_id} directly.
|
||||
"""
|
||||
@@ -295,8 +293,7 @@ class LegacyConversationManager(ConversationManager):
|
||||
return is_legacy
|
||||
|
||||
def is_legacy_runtime(self, runtime: dict | None) -> bool:
|
||||
"""
|
||||
Determine if a runtime is a legacy runtime based on its command.
|
||||
"""Determine if a runtime is a legacy runtime based on its command.
|
||||
|
||||
Args:
|
||||
runtime: The runtime dictionary or None if not found
|
||||
|
||||
@@ -59,11 +59,9 @@ def setup_json_logger(
|
||||
level: str = LOG_LEVEL,
|
||||
_out: TextIO = sys.stdout,
|
||||
) -> None:
|
||||
"""
|
||||
Configure logger instance to output json for Google Cloud.
|
||||
"""Configure logger instance to output json for Google Cloud.
|
||||
Existing filters should stay in place for sensitive content.
|
||||
"""
|
||||
|
||||
# Remove existing handlers to avoid duplicate logs
|
||||
for handler in logger.handlers[:]:
|
||||
logger.removeHandler(handler)
|
||||
@@ -84,8 +82,7 @@ def setup_json_logger(
|
||||
|
||||
|
||||
def setup_all_loggers():
|
||||
"""
|
||||
Setup JSON logging for all libraries that may be logging.
|
||||
"""Setup JSON logging for all libraries that may be logging.
|
||||
Leave OpenHands alone since it's already configured.
|
||||
"""
|
||||
if LOG_JSON:
|
||||
|
||||
@@ -13,8 +13,7 @@ from openhands.core.config import load_openhands_config
|
||||
|
||||
|
||||
class UserVersionUpgradeProcessor(MaintenanceTaskProcessor):
|
||||
"""
|
||||
Processor for upgrading user settings to the current version.
|
||||
"""Processor for upgrading user settings to the current version.
|
||||
|
||||
This processor takes a list of user IDs and upgrades any users
|
||||
whose user_version is less than CURRENT_USER_SETTINGS_VERSION.
|
||||
@@ -23,8 +22,7 @@ class UserVersionUpgradeProcessor(MaintenanceTaskProcessor):
|
||||
user_ids: List[str]
|
||||
|
||||
async def __call__(self, task: MaintenanceTask) -> dict:
|
||||
"""
|
||||
Process user version upgrades for the specified user IDs.
|
||||
"""Process user version upgrades for the specified user IDs.
|
||||
|
||||
Args:
|
||||
task: The maintenance task being processed
|
||||
|
||||
@@ -27,8 +27,7 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
|
||||
def create_default_mcp_server_config(
|
||||
host: str, config: 'OpenHandsConfig', user_id: str | None = None
|
||||
) -> tuple[MCPSHTTPServerConfig | None, list[MCPStdioServerConfig]]:
|
||||
"""
|
||||
Create a default MCP server configuration.
|
||||
"""Create a default MCP server configuration.
|
||||
|
||||
Args:
|
||||
host: Host string
|
||||
@@ -36,7 +35,6 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
|
||||
Returns:
|
||||
A tuple containing the default SSE server configuration and a list of MCP stdio server configurations
|
||||
"""
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
if user_id:
|
||||
api_key = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
@@ -33,8 +33,7 @@ def metrics_app() -> Callable:
|
||||
metrics_callable = make_asgi_app()
|
||||
|
||||
async def wrapped_handler(scope, receive, send):
|
||||
"""
|
||||
Call _update_metrics before serving Prometheus metrics endpoint.
|
||||
"""Call _update_metrics before serving Prometheus metrics endpoint.
|
||||
Not wrapped in a `try`, failing would make metrics endpoint unavailable.
|
||||
"""
|
||||
await _update_metrics()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from datetime import UTC, datetime
|
||||
from typing import Callable
|
||||
|
||||
import jwt
|
||||
from fastapi import Request, Response, status
|
||||
from fastapi import HTTPException, Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import SecretStr
|
||||
from server.auth.auth_error import (
|
||||
@@ -12,21 +13,25 @@ from server.auth.auth_error import (
|
||||
)
|
||||
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
||||
from server.auth.saas_user_auth import SaasUserAuth, token_manager
|
||||
from server.constants import get_default_litellm_model
|
||||
from server.routes.auth import (
|
||||
get_cookie_domain,
|
||||
get_cookie_samesite,
|
||||
set_response_cookie,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth.user_auth import AuthType, get_user_auth
|
||||
from openhands.server.utils import config
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
class SetAuthCookieMiddleware:
|
||||
"""
|
||||
Update the auth cookie with the current authentication state if it was refreshed before sending response to user.
|
||||
Deleting invalid cookies is handled by CookieError using FastAPIs standard error handling mechanism
|
||||
"""Update the auth cookie with the current authentication state if it was refreshed before sending response to user.
|
||||
|
||||
Deleting invalid cookies is handled by CookieError using FastAPIs standard error handling mechanism.
|
||||
"""
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable):
|
||||
@@ -172,3 +177,247 @@ class SetAuthCookieMiddleware:
|
||||
await token_manager.logout(user_auth.refresh_token.get_secret_value())
|
||||
except Exception:
|
||||
logger.debug('Error logging out')
|
||||
|
||||
|
||||
class LLMSettingsMiddleware:
|
||||
"""Middleware to validate LLM settings access for enterprise users.
|
||||
|
||||
Intercepts POST requests to /api/settings and validates that non-pro users
|
||||
cannot modify LLM-related settings.
|
||||
"""
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable):
|
||||
try:
|
||||
logger.warning(
|
||||
f'LLM middleware called for {request.method} {request.url.path}'
|
||||
)
|
||||
|
||||
# Check if this is a POST request to /api/settings
|
||||
if request.method == 'POST' and request.url.path == '/api/settings':
|
||||
logger.warning('LLM middleware intercepting POST /api/settings request')
|
||||
await self._validate_llm_settings_request(request)
|
||||
|
||||
# Continue with the request
|
||||
response: Response = await call_next(request)
|
||||
return response
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPException (our 403 response)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f'Error in LLM settings middleware: {e}')
|
||||
# Let other errors pass through to be handled by the route
|
||||
fallback_response: Response = await call_next(request)
|
||||
return fallback_response
|
||||
|
||||
async def _validate_llm_settings_request(self, request: Request) -> None:
|
||||
"""Validate LLM settings access for the current request."""
|
||||
try:
|
||||
logger.info(
|
||||
f"LLM settings middleware intercepting POST /api/settings from {request.client.host if request.client else 'unknown'}"
|
||||
)
|
||||
|
||||
# Get user authentication - this will trigger authentication if not already done
|
||||
try:
|
||||
user_auth = await get_user_auth(request)
|
||||
except Exception as e:
|
||||
logger.info(f'No valid user auth found ({e}), letting route handle request')
|
||||
return # No user auth, let the route handle it
|
||||
|
||||
user_id = await user_auth.get_user_id()
|
||||
if not user_id:
|
||||
logger.info('No user ID found, letting route handle request')
|
||||
return # No user ID, let the route handle it
|
||||
|
||||
logger.info(f'Processing settings request for user: {user_id}')
|
||||
|
||||
# Parse the request JSON to get new settings
|
||||
try:
|
||||
settings_data = await request.json()
|
||||
logger.info(f'Parsed settings data keys: {list(settings_data.keys())}')
|
||||
except Exception as e:
|
||||
logger.warning(f'Invalid JSON in request body: {e}')
|
||||
return # Invalid JSON, let the route handle it
|
||||
|
||||
# Convert to Settings object for validation
|
||||
try:
|
||||
new_settings = Settings(**settings_data)
|
||||
logger.info('Successfully created Settings object from request data')
|
||||
except Exception as e:
|
||||
logger.warning(f'Invalid settings format: {e}')
|
||||
return # Invalid settings format, let the route handle it
|
||||
|
||||
# Validate LLM settings access by comparing new settings against SaaS defaults
|
||||
await validate_llm_settings_access(user_id, new_settings)
|
||||
logger.info(f'LLM settings validation passed for user {user_id}')
|
||||
|
||||
except HTTPException as e:
|
||||
logger.warning(
|
||||
f'LLM settings validation failed: HTTP {e.status_code} - {e.detail}'
|
||||
)
|
||||
# Re-raise our 403 response
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f'Unexpected error validating LLM settings request: {e}')
|
||||
# Let other errors pass through
|
||||
|
||||
|
||||
def _get_saas_default_settings() -> Settings:
|
||||
"""Get the default SaaS settings for comparison."""
|
||||
return Settings(
|
||||
language='en',
|
||||
agent='CodeActAgent',
|
||||
enable_proactive_conversation_starters=True,
|
||||
enable_default_condenser=True,
|
||||
condenser_max_size=120,
|
||||
llm_model=get_default_litellm_model(), # litellm_proxy/prod/claude-sonnet-4-20250514
|
||||
confirmation_mode=False,
|
||||
security_analyzer='llm',
|
||||
# Note: llm_api_key and llm_base_url are auto-provisioned for SaaS users,
|
||||
# so we don't include them in defaults - any custom values are changes
|
||||
)
|
||||
|
||||
|
||||
def has_llm_settings_changes(user_settings: Settings, saas_defaults: Settings) -> bool:
|
||||
"""Check if user settings contain changes to LLM-related settings from SaaS defaults."""
|
||||
logger.info(
|
||||
f"Checking LLM settings changes - User settings: {user_settings.model_dump(exclude={'secrets_store'})}"
|
||||
)
|
||||
logger.info(
|
||||
f"Checking LLM settings changes - SaaS defaults: {saas_defaults.model_dump(exclude={'secrets_store'})}"
|
||||
)
|
||||
|
||||
# Core LLM settings - any custom values are changes since SaaS auto-provisions these
|
||||
if (
|
||||
user_settings.llm_model is not None
|
||||
and user_settings.llm_model != saas_defaults.llm_model
|
||||
):
|
||||
logger.warning(
|
||||
f"LLM model change detected: user='{user_settings.llm_model}' vs default='{saas_defaults.llm_model}'"
|
||||
)
|
||||
return True
|
||||
if user_settings.llm_api_key is not None:
|
||||
# Any custom API key is a change (SaaS users get auto-provisioned keys)
|
||||
logger.warning(
|
||||
f'LLM API key change detected: user has custom key (length={len(user_settings.llm_api_key.get_secret_value()) if user_settings.llm_api_key else 0})'
|
||||
)
|
||||
return True
|
||||
if user_settings.llm_base_url is not None and user_settings.llm_base_url != '':
|
||||
# Any non-empty base URL is a change (SaaS users get auto-provisioned URL)
|
||||
logger.warning(
|
||||
f"LLM base URL change detected: user='{user_settings.llm_base_url}' (non-empty)"
|
||||
)
|
||||
return True
|
||||
|
||||
# LLM-related configuration settings
|
||||
if user_settings.agent is not None and user_settings.agent != saas_defaults.agent:
|
||||
logger.warning(
|
||||
f"Agent change detected: user='{user_settings.agent}' vs default='{saas_defaults.agent}'"
|
||||
)
|
||||
return True
|
||||
if (
|
||||
user_settings.confirmation_mode is not None
|
||||
and user_settings.confirmation_mode != saas_defaults.confirmation_mode
|
||||
):
|
||||
logger.warning(
|
||||
f'Confirmation mode change detected: user={user_settings.confirmation_mode} vs default={saas_defaults.confirmation_mode}'
|
||||
)
|
||||
return True
|
||||
if (
|
||||
user_settings.security_analyzer is not None
|
||||
and user_settings.security_analyzer != saas_defaults.security_analyzer
|
||||
and user_settings.security_analyzer != ''
|
||||
): # Handle empty string as None
|
||||
logger.warning(
|
||||
f"Security analyzer change detected: user='{user_settings.security_analyzer}' vs default='{saas_defaults.security_analyzer}'"
|
||||
)
|
||||
return True
|
||||
if user_settings.max_budget_per_task is not None:
|
||||
logger.warning(
|
||||
f'Max budget per task change detected: user={user_settings.max_budget_per_task}'
|
||||
)
|
||||
return True
|
||||
if user_settings.max_iterations is not None:
|
||||
logger.warning(
|
||||
f'Max iterations change detected: user={user_settings.max_iterations}'
|
||||
)
|
||||
return True
|
||||
|
||||
# Memory/context management settings
|
||||
if user_settings.enable_default_condenser != saas_defaults.enable_default_condenser:
|
||||
logger.warning(
|
||||
f'Enable default condenser change detected: user={user_settings.enable_default_condenser} vs default={saas_defaults.enable_default_condenser}'
|
||||
)
|
||||
return True
|
||||
if (
|
||||
user_settings.condenser_max_size is not None
|
||||
and user_settings.condenser_max_size != saas_defaults.condenser_max_size
|
||||
):
|
||||
logger.warning(
|
||||
f'Condenser max size change detected: user={user_settings.condenser_max_size} vs default={saas_defaults.condenser_max_size}'
|
||||
)
|
||||
return True
|
||||
|
||||
logger.info('No LLM settings changes detected')
|
||||
return False
|
||||
|
||||
|
||||
def _has_active_subscription(user_id: str) -> bool:
|
||||
"""Check if user has an active subscription (pro user)."""
|
||||
with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
logger.info(f'Checking subscription for user {user_id} at time {now}')
|
||||
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access:
|
||||
logger.info(
|
||||
f'Found active subscription for user {user_id}: starts={subscription_access.start_at}, ends={subscription_access.end_at}'
|
||||
)
|
||||
else:
|
||||
logger.info(f'No active subscription found for user {user_id}')
|
||||
|
||||
return subscription_access is not None
|
||||
|
||||
|
||||
async def validate_llm_settings_access(
|
||||
user_id: str, user_settings: Settings, saas_defaults: Settings | None = None
|
||||
) -> None:
|
||||
"""Validate that user has permission to change LLM settings.
|
||||
|
||||
Raises HTTPException with 403 status if non-pro user tries to change LLM settings.
|
||||
"""
|
||||
if saas_defaults is None:
|
||||
saas_defaults = _get_saas_default_settings()
|
||||
|
||||
logger.info(f'Validating LLM settings access for user: {user_id}')
|
||||
|
||||
# Check if user is trying to change LLM settings
|
||||
if has_llm_settings_changes(user_settings, saas_defaults):
|
||||
logger.warning(f'User {user_id} attempting to change LLM settings')
|
||||
|
||||
# Check if user has active subscription (is pro user)
|
||||
has_subscription = _has_active_subscription(user_id)
|
||||
logger.info(
|
||||
f"User {user_id} subscription status: {'active' if has_subscription else 'none'}"
|
||||
)
|
||||
|
||||
if not has_subscription:
|
||||
logger.warning(
|
||||
f'Blocking non-pro user {user_id} from changing LLM settings'
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='LLM settings can only be modified by pro users',
|
||||
)
|
||||
else:
|
||||
logger.info(f'Allowing pro user {user_id} to change LLM settings')
|
||||
else:
|
||||
logger.info(f'User {user_id} making non-LLM settings changes only - allowing')
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Usage:
|
||||
"""Usage:
|
||||
|
||||
Call setup_rate_limit_handler on your FastAPI app to add the exception handler
|
||||
|
||||
@@ -23,9 +22,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def setup_rate_limit_handler(app: Starlette):
|
||||
"""
|
||||
Add exception handler that
|
||||
"""
|
||||
"""Add exception handler that"""
|
||||
app.add_exception_handler(RateLimitException, _rate_limit_exceeded_handler)
|
||||
|
||||
|
||||
@@ -56,8 +53,7 @@ class RateLimiter:
|
||||
self.limit_items = limits.parse_many(windows)
|
||||
|
||||
async def hit(self, namespace: str, key: str):
|
||||
"""
|
||||
Raises RateLimitException when limit is hit.
|
||||
"""Raises RateLimitException when limit is hit.
|
||||
Logs and swallows exceptions and logs if lookup fails.
|
||||
"""
|
||||
for lim in self.limit_items:
|
||||
@@ -80,9 +76,7 @@ class RateLimiter:
|
||||
async def _get_stats_as_result(
|
||||
self, lim: limits.RateLimitItem, namespace: str, key: str
|
||||
) -> RateLimitResult:
|
||||
"""
|
||||
Lookup rate limit window stats and return a RateLimitResult with the data needed for response headers.
|
||||
"""
|
||||
"""Lookup rate limit window stats and return a RateLimitResult with the data needed for response headers."""
|
||||
stats: limits.WindowStats = await self.strategy.get_window_stats(
|
||||
lim, namespace, key
|
||||
)
|
||||
@@ -97,8 +91,7 @@ class RateLimiter:
|
||||
|
||||
|
||||
def create_redis_rate_limiter(windows: str) -> RateLimiter:
|
||||
"""
|
||||
Create a RateLimiter with the Redis backend and "Fixed Window" strategy.
|
||||
"""Create a RateLimiter with the Redis backend and "Fixed Window" strategy.
|
||||
windows arg example: "10/second; 100/minute"
|
||||
"""
|
||||
backend = limits.aio.storage.RedisStorage(f'async+{get_redis_authed_url()}')
|
||||
@@ -107,9 +100,7 @@ def create_redis_rate_limiter(windows: str) -> RateLimiter:
|
||||
|
||||
|
||||
class RateLimitException(HTTPException):
|
||||
"""
|
||||
exception raised when a rate limit is hit.
|
||||
"""
|
||||
"""exception raised when a rate limit is hit."""
|
||||
|
||||
result: RateLimitResult
|
||||
|
||||
@@ -121,9 +112,7 @@ class RateLimitException(HTTPException):
|
||||
|
||||
|
||||
def _rate_limit_exceeded_handler(request: Request, exc: Exception) -> Response:
|
||||
"""
|
||||
Build a simple JSON response that includes the details of the rate limit that was hit.
|
||||
"""
|
||||
"""Build a simple JSON response that includes the details of the rate limit that was hit."""
|
||||
logger.info(exc.__class__.__name__)
|
||||
if isinstance(exc, RateLimitException):
|
||||
response = JSONResponse(
|
||||
|
||||
@@ -17,8 +17,7 @@ ADD_DEBUGGING_ROUTES = os.environ.get('ADD_DEBUGGING_ROUTES') in ('1', 'true')
|
||||
|
||||
|
||||
def add_debugging_routes(api: FastAPI):
|
||||
"""
|
||||
# HERE BE DRAGONS!
|
||||
"""# HERE BE DRAGONS!
|
||||
Chaos scripts for debugging and stress testing the system.
|
||||
|
||||
This module contains endpoints that deliberately stress test and potentially break
|
||||
@@ -31,7 +30,6 @@ def add_debugging_routes(api: FastAPI):
|
||||
- Testing async vs sync database access patterns
|
||||
- Simulating event loop blocking
|
||||
"""
|
||||
|
||||
if not ADD_DEBUGGING_ROUTES:
|
||||
return
|
||||
|
||||
@@ -39,8 +37,7 @@ def add_debugging_routes(api: FastAPI):
|
||||
|
||||
@chaos_router.get('/pool-stats')
|
||||
def pool_stats() -> dict[str, int]:
|
||||
"""
|
||||
Returns current database connection pool statistics.
|
||||
"""Returns current database connection pool statistics.
|
||||
|
||||
This endpoint provides real-time metrics about the SQLAlchemy connection pool:
|
||||
- checked_in: Number of connections currently available in the pool
|
||||
@@ -55,8 +52,7 @@ def add_debugging_routes(api: FastAPI):
|
||||
|
||||
@chaos_router.get('/test-db')
|
||||
def test_db(num_tests: int = 10, delay: int = 1) -> str:
|
||||
"""
|
||||
Stress tests the database connection pool using multiple threads.
|
||||
"""Stress tests the database connection pool using multiple threads.
|
||||
|
||||
Creates multiple threads that each open a database connection, perform a query,
|
||||
hold the connection for the specified delay, and then release it.
|
||||
@@ -77,8 +73,7 @@ def add_debugging_routes(api: FastAPI):
|
||||
|
||||
@chaos_router.get('/a-test-db')
|
||||
async def a_chaos_monkey(num_tests: int = 10, delay: int = 1) -> str:
|
||||
"""
|
||||
Stress tests the async database connection pool.
|
||||
"""Stress tests the async database connection pool.
|
||||
|
||||
Similar to /test-db but uses async connections and coroutines instead of threads.
|
||||
This endpoint helps compare the behavior of async vs sync connection pools
|
||||
@@ -93,8 +88,7 @@ def add_debugging_routes(api: FastAPI):
|
||||
|
||||
@chaos_router.get('/lock-main-runloop')
|
||||
async def lock_main_runloop(duration: int = 10) -> str:
|
||||
"""
|
||||
Deliberately blocks the main asyncio event loop.
|
||||
"""Deliberately blocks the main asyncio event loop.
|
||||
|
||||
This endpoint uses a synchronous sleep operation in an async function,
|
||||
which blocks the entire FastAPI server's event loop for the specified duration.
|
||||
@@ -113,8 +107,7 @@ def add_debugging_routes(api: FastAPI):
|
||||
|
||||
|
||||
def _db_check(delay: int):
|
||||
"""
|
||||
Executes a single request against the database with an artificial delay.
|
||||
"""Executes a single request against the database with an artificial delay.
|
||||
|
||||
This helper function:
|
||||
1. Opens a database connection from the pool
|
||||
@@ -141,8 +134,7 @@ def _db_check(delay: int):
|
||||
|
||||
|
||||
async def _a_db_check(delay: int):
|
||||
"""
|
||||
Executes a single async request against the database with an artificial delay.
|
||||
"""Executes a single async request against the database with an artificial delay.
|
||||
|
||||
This is the async version of _db_check that:
|
||||
1. Opens an async database connection from the pool
|
||||
|
||||
@@ -73,8 +73,7 @@ class FeedbackRequest(BaseModel):
|
||||
|
||||
@router.post('/conversation', status_code=status.HTTP_201_CREATED)
|
||||
async def submit_conversation_feedback(feedback: FeedbackRequest):
|
||||
"""
|
||||
Submit feedback for a conversation.
|
||||
"""Submit feedback for a conversation.
|
||||
|
||||
This endpoint accepts a rating (1-5) and optional reason for the feedback.
|
||||
The feedback is associated with a specific conversation and optionally a specific event.
|
||||
@@ -108,8 +107,7 @@ async def submit_conversation_feedback(feedback: FeedbackRequest):
|
||||
|
||||
@router.get('/conversation/{conversation_id}/batch')
|
||||
async def get_batch_feedback(conversation_id: str, user_id: str = Depends(get_user_id)):
|
||||
"""
|
||||
Get feedback for all events in a conversation.
|
||||
"""Get feedback for all events in a conversation.
|
||||
|
||||
Returns feedback status for each event, including whether feedback exists
|
||||
and if so, the rating and reason.
|
||||
|
||||
@@ -16,8 +16,7 @@ GITHUB_PROXY_ENDPOINTS = bool(os.environ.get('GITHUB_PROXY_ENDPOINTS'))
|
||||
|
||||
|
||||
def add_github_proxy_routes(app: FastAPI):
|
||||
"""
|
||||
Authentication endpoints for feature branches.
|
||||
"""Authentication endpoints for feature branches.
|
||||
|
||||
# Requirements
|
||||
* This should never be enabled in prod!
|
||||
|
||||
@@ -23,14 +23,10 @@ AGENT_SESSION_START_HISTOGRAM = Histogram(
|
||||
|
||||
|
||||
class SaaSMonitoringListener(MonitoringListener):
|
||||
"""
|
||||
Forward app signals to Prometheus.
|
||||
"""
|
||||
"""Forward app signals to Prometheus."""
|
||||
|
||||
def on_session_event(self, event: Event) -> None:
|
||||
"""
|
||||
Track metrics about events being added to a Session's EventStream.
|
||||
"""
|
||||
"""Track metrics about events being added to a Session's EventStream."""
|
||||
if (
|
||||
isinstance(event, AgentStateChangedObservation)
|
||||
and event.agent_state == AgentState.ERROR
|
||||
@@ -42,8 +38,7 @@ class SaaSMonitoringListener(MonitoringListener):
|
||||
)
|
||||
|
||||
def on_agent_session_start(self, success: bool, duration: float) -> None:
|
||||
"""
|
||||
Track an agent session start.
|
||||
"""Track an agent session start.
|
||||
Success is true if startup completed without error.
|
||||
Duration is start time in seconds observed by AgentSession.
|
||||
"""
|
||||
@@ -58,8 +53,7 @@ class SaaSMonitoringListener(MonitoringListener):
|
||||
)
|
||||
|
||||
def on_create_conversation(self) -> None:
|
||||
"""
|
||||
Track the beginning of conversation creation.
|
||||
"""Track the beginning of conversation creation.
|
||||
Does not currently capture whether it succeed.
|
||||
"""
|
||||
CREATE_CONVERSATION_COUNT.inc()
|
||||
|
||||
@@ -131,9 +131,7 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
async def get_running_agent_loops(
|
||||
self, user_id: str | None = None, filter_to_sids: set[str] | None = None
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get the running agent loops directly from the remote runtime.
|
||||
"""
|
||||
"""Get the running agent loops directly from the remote runtime."""
|
||||
conversation_ids = await self._get_all_running_conversation_ids()
|
||||
|
||||
if filter_to_sids is not None:
|
||||
@@ -482,10 +480,7 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
)
|
||||
|
||||
def _get_user_id_from_conversation(self, conversation_id: str) -> str:
|
||||
"""
|
||||
Get user_id from conversation_id.
|
||||
"""
|
||||
|
||||
"""Get user_id from conversation_id."""
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
|
||||
@@ -34,8 +34,7 @@ file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
async def process_event(
|
||||
user_id: str, conversation_id: str, subpath: str, content: dict
|
||||
):
|
||||
"""
|
||||
Process a conversation event and invoke any registered callbacks.
|
||||
"""Process a conversation event and invoke any registered callbacks.
|
||||
|
||||
Args:
|
||||
user_id: The user ID associated with the conversation
|
||||
@@ -72,8 +71,7 @@ async def process_event(
|
||||
async def invoke_conversation_callbacks(
|
||||
conversation_id: str, observation: AgentStateChangedObservation
|
||||
):
|
||||
"""
|
||||
Load and invoke all active callbacks for a conversation.
|
||||
"""Load and invoke all active callbacks for a conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation ID to process callbacks for
|
||||
@@ -119,8 +117,7 @@ async def invoke_conversation_callbacks(
|
||||
|
||||
|
||||
def update_conversation_metadata(conversation_id: str, content: dict):
|
||||
"""
|
||||
Update conversation metadata with new content.
|
||||
"""Update conversation metadata with new content.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation ID to update
|
||||
@@ -159,8 +156,7 @@ def update_conversation_metadata(conversation_id: str, content: dict):
|
||||
def register_callback_processor(
|
||||
conversation_id: str, processor: ConversationCallbackProcessor
|
||||
) -> int:
|
||||
"""
|
||||
Register a callback processor for a conversation.
|
||||
"""Register a callback processor for a conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation ID to register the callback for
|
||||
@@ -182,8 +178,7 @@ def register_callback_processor(
|
||||
def update_active_working_seconds(
|
||||
event_store: EventStore, conversation_id: str, user_id: str, file_store: FileStore
|
||||
):
|
||||
"""
|
||||
Calculate and update the total active working seconds for a conversation.
|
||||
"""Calculate and update the total active working seconds for a conversation.
|
||||
|
||||
This function reads all events for the conversation, looks for AgentStateChanged
|
||||
observations, and calculates the total time spent in a running state.
|
||||
@@ -263,8 +258,7 @@ def update_active_working_seconds(
|
||||
|
||||
|
||||
def update_agent_state(user_id: str, conversation_id: str, content: bytes):
|
||||
"""
|
||||
Update agent state file for a conversation.
|
||||
"""Update agent state file for a conversation.
|
||||
|
||||
Args:
|
||||
user_id: The user ID associated with the conversation
|
||||
|
||||
@@ -3,9 +3,7 @@ from storage.base import Base
|
||||
|
||||
|
||||
class ApiKey(Base):
|
||||
"""
|
||||
Represents an API key for a user.
|
||||
"""
|
||||
"""Represents an API key for a user."""
|
||||
|
||||
__tablename__ = 'api_keys'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
@@ -73,8 +73,7 @@ class AuthTokenStore:
|
||||
]
|
||||
| None = None,
|
||||
) -> Dict[str, str | int] | None:
|
||||
"""
|
||||
Load authentication tokens from the database and refresh them if necessary.
|
||||
"""Load authentication tokens from the database and refresh them if necessary.
|
||||
|
||||
This method retrieves the current authentication tokens for the user and checks if they have expired.
|
||||
It uses the provided `check_expiration_and_refresh` function to determine if the tokens need
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Unified SQLAlchemy declarative base for all models.
|
||||
"""
|
||||
"""Unified SQLAlchemy declarative base for all models."""
|
||||
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
|
||||
@@ -5,8 +5,7 @@ from storage.base import Base
|
||||
|
||||
|
||||
class BillingSession(Base): # type: ignore
|
||||
"""
|
||||
Represents a Stripe billing session for credit purchases.
|
||||
"""Represents a Stripe billing session for credit purchases.
|
||||
Tracks the status of payment transactions and associated user information.
|
||||
"""
|
||||
|
||||
|
||||
@@ -15,8 +15,7 @@ from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class ConversationCallbackProcessor(BaseModel, ABC):
|
||||
"""
|
||||
Abstract base class for conversation callback processors.
|
||||
"""Abstract base class for conversation callback processors.
|
||||
|
||||
Conversation processors are invoked when events occur in a conversation
|
||||
to perform additional processing, notifications, or integrations.
|
||||
@@ -35,8 +34,7 @@ class ConversationCallbackProcessor(BaseModel, ABC):
|
||||
callback: ConversationCallback,
|
||||
observation: AgentStateChangedObservation,
|
||||
) -> None:
|
||||
"""
|
||||
Process a conversation event.
|
||||
"""Process a conversation event.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation to process
|
||||
@@ -54,9 +52,7 @@ class CallbackStatus(Enum):
|
||||
|
||||
|
||||
class ConversationCallback(Base): # type: ignore
|
||||
"""
|
||||
Model for storing conversation callbacks that process conversation events.
|
||||
"""
|
||||
"""Model for storing conversation callbacks that process conversation events."""
|
||||
|
||||
__tablename__ = 'conversation_callbacks'
|
||||
|
||||
@@ -85,8 +81,7 @@ class ConversationCallback(Base): # type: ignore
|
||||
)
|
||||
|
||||
def get_processor(self) -> ConversationCallbackProcessor:
|
||||
"""
|
||||
Get the processor instance from the stored processor type and JSON data.
|
||||
"""Get the processor instance from the stored processor type and JSON data.
|
||||
|
||||
Returns:
|
||||
ConversationCallbackProcessor: The processor instance
|
||||
@@ -99,8 +94,7 @@ class ConversationCallback(Base): # type: ignore
|
||||
return processor
|
||||
|
||||
def set_processor(self, processor: ConversationCallbackProcessor) -> None:
|
||||
"""
|
||||
Set the processor instance, storing its type and JSON representation.
|
||||
"""Set the processor instance, storing its type and JSON representation.
|
||||
|
||||
Args:
|
||||
processor: The ConversationCallbackProcessor instance to store
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Database model for experiment assignments.
|
||||
"""Database model for experiment assignments.
|
||||
|
||||
This model tracks which experiments a conversation is assigned to and what variant
|
||||
they received from PostHog feature flags.
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Store for managing experiment assignments.
|
||||
"""Store for managing experiment assignments.
|
||||
|
||||
This store handles creating and updating experiment assignments for conversations.
|
||||
"""
|
||||
@@ -20,8 +19,7 @@ class ExperimentAssignmentStore:
|
||||
experiment_name: str,
|
||||
variant: str,
|
||||
) -> None:
|
||||
"""
|
||||
Update the variant for a specific experiment.
|
||||
"""Update the variant for a specific experiment.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation ID
|
||||
|
||||
@@ -3,9 +3,7 @@ from storage.base import Base
|
||||
|
||||
|
||||
class GithubAppInstallation(Base): # type: ignore
|
||||
"""
|
||||
Represents a Github App Installation with associated token.
|
||||
"""
|
||||
"""Represents a Github App Installation with associated token."""
|
||||
|
||||
__tablename__ = 'github_app_installations'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
@@ -13,9 +13,7 @@ class WebhookStatus(IntEnum):
|
||||
|
||||
|
||||
class GitlabWebhook(Base): # type: ignore
|
||||
"""
|
||||
Represents a Gitlab webhook configuration for a repository or group.
|
||||
"""
|
||||
"""Represents a Gitlab webhook configuration for a repository or group."""
|
||||
|
||||
__tablename__ = 'gitlab_webhook'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
@@ -86,7 +86,6 @@ class GitlabWebhookStore:
|
||||
Raises:
|
||||
ValueError: If neither project_id nor group_id is provided, or if both are provided.
|
||||
"""
|
||||
|
||||
resource_type, resource_id = GitlabWebhookStore.determine_resource_type(webhook)
|
||||
async with self.a_session_maker() as session:
|
||||
async with session.begin():
|
||||
@@ -110,7 +109,6 @@ class GitlabWebhookStore:
|
||||
Raises:
|
||||
ValueError: If neither project_id nor group_id is provided, or if both are provided.
|
||||
"""
|
||||
|
||||
resource_type, resource_id = GitlabWebhookStore.determine_resource_type(webhook)
|
||||
|
||||
logger.info(
|
||||
@@ -184,7 +182,6 @@ class GitlabWebhookStore:
|
||||
Returns:
|
||||
List of GitlabWebhook objects that need processing
|
||||
"""
|
||||
|
||||
async with self.a_session_maker() as session:
|
||||
query = (
|
||||
select(GitlabWebhook)
|
||||
@@ -198,9 +195,7 @@ class GitlabWebhookStore:
|
||||
return list(webhooks)
|
||||
|
||||
async def get_webhook_secret(self, webhook_uuid: str, user_id: str) -> str | None:
|
||||
"""
|
||||
Get's webhook secret given the webhook uuid and admin keycloak user id
|
||||
"""
|
||||
"""Get's webhook secret given the webhook uuid and admin keycloak user id"""
|
||||
async with self.a_session_maker() as session:
|
||||
query = (
|
||||
select(GitlabWebhook)
|
||||
|
||||
@@ -23,7 +23,6 @@ class JiraDcIntegrationStore:
|
||||
status: str = 'active',
|
||||
) -> JiraDcWorkspace:
|
||||
"""Create a new Jira DC workspace with encrypted sensitive data."""
|
||||
|
||||
with session_maker() as session:
|
||||
workspace = JiraDcWorkspace(
|
||||
name=name.lower(),
|
||||
@@ -83,7 +82,6 @@ class JiraDcIntegrationStore:
|
||||
status: str = 'active',
|
||||
) -> JiraDcUser:
|
||||
"""Create a new Jira DC workspace link."""
|
||||
|
||||
jira_dc_user = JiraDcUser(
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
jira_dc_user_id=jira_dc_user_id,
|
||||
@@ -125,7 +123,6 @@ class JiraDcIntegrationStore:
|
||||
self, keycloak_user_id: str
|
||||
) -> Optional[JiraDcUser]:
|
||||
"""Retrieve user by Keycloak user ID."""
|
||||
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(JiraDcUser)
|
||||
@@ -184,7 +181,6 @@ class JiraDcIntegrationStore:
|
||||
self, keycloak_user_id: str, status: str
|
||||
) -> JiraDcUser:
|
||||
"""Update the status of a Jira DC user mapping."""
|
||||
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(JiraDcUser)
|
||||
|
||||
@@ -24,7 +24,6 @@ class JiraIntegrationStore:
|
||||
status: str = 'active',
|
||||
) -> JiraWorkspace:
|
||||
"""Create a new Jira workspace with encrypted sensitive data."""
|
||||
|
||||
workspace = JiraWorkspace(
|
||||
name=name.lower(),
|
||||
jira_cloud_id=jira_cloud_id,
|
||||
@@ -91,7 +90,6 @@ class JiraIntegrationStore:
|
||||
status: str = 'active',
|
||||
) -> JiraUser:
|
||||
"""Create a new Jira workspace link."""
|
||||
|
||||
jira_user = JiraUser(
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
jira_user_id=jira_user_id,
|
||||
|
||||
@@ -24,7 +24,6 @@ class LinearIntegrationStore:
|
||||
status: str = 'active',
|
||||
) -> LinearWorkspace:
|
||||
"""Create a new Linear workspace with encrypted sensitive data."""
|
||||
|
||||
workspace = LinearWorkspace(
|
||||
name=name.lower(),
|
||||
linear_org_id=linear_org_id,
|
||||
|
||||
@@ -15,8 +15,7 @@ from openhands.utils.import_utils import get_impl
|
||||
|
||||
|
||||
class MaintenanceTaskProcessor(BaseModel, ABC):
|
||||
"""
|
||||
Abstract base class for maintenance task processors.
|
||||
"""Abstract base class for maintenance task processors.
|
||||
|
||||
Maintenance processors are invoked to perform background maintenance
|
||||
tasks such as upgrading user settings, cleaning up data, etc.
|
||||
@@ -31,8 +30,7 @@ class MaintenanceTaskProcessor(BaseModel, ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, task: MaintenanceTask) -> dict:
|
||||
"""
|
||||
Process a maintenance task.
|
||||
"""Process a maintenance task.
|
||||
|
||||
Args:
|
||||
task: The maintenance task to process
|
||||
@@ -53,9 +51,7 @@ class MaintenanceTaskStatus(Enum):
|
||||
|
||||
|
||||
class MaintenanceTask(Base): # type: ignore
|
||||
"""
|
||||
Model for storing maintenance tasks that perform background operations.
|
||||
"""
|
||||
"""Model for storing maintenance tasks that perform background operations."""
|
||||
|
||||
__tablename__ = 'maintenance_tasks'
|
||||
|
||||
@@ -83,8 +79,7 @@ class MaintenanceTask(Base): # type: ignore
|
||||
)
|
||||
|
||||
def get_processor(self) -> MaintenanceTaskProcessor:
|
||||
"""
|
||||
Get the processor instance from the stored processor type and JSON data.
|
||||
"""Get the processor instance from the stored processor type and JSON data.
|
||||
|
||||
Returns:
|
||||
MaintenanceTaskProcessor: The processor instance
|
||||
@@ -97,8 +92,7 @@ class MaintenanceTask(Base): # type: ignore
|
||||
return processor
|
||||
|
||||
def set_processor(self, processor: MaintenanceTaskProcessor) -> None:
|
||||
"""
|
||||
Set the processor instance, storing its type and JSON representation.
|
||||
"""Set the processor instance, storing its type and JSON representation.
|
||||
|
||||
Args:
|
||||
processor: The MaintenanceTaskProcessor instance to store
|
||||
|
||||
@@ -13,9 +13,7 @@ from storage.base import Base
|
||||
|
||||
|
||||
class OpenhandsPR(Base): # type: ignore
|
||||
"""
|
||||
Represents a pull request created by OpenHands.
|
||||
"""
|
||||
"""Represents a pull request created by OpenHands."""
|
||||
|
||||
__tablename__ = 'openhands_prs'
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
|
||||
@@ -15,9 +15,7 @@ class OpenhandsPRStore:
|
||||
session_maker: sessionmaker
|
||||
|
||||
def insert_pr(self, pr: OpenhandsPR) -> None:
|
||||
"""
|
||||
Insert a new PR or delete and recreate if repo_id and pr_number already exist.
|
||||
"""
|
||||
"""Insert a new PR or delete and recreate if repo_id and pr_number already exist."""
|
||||
with self.session_maker() as session:
|
||||
# Check if PR already exists
|
||||
existing_pr = (
|
||||
@@ -39,8 +37,7 @@ class OpenhandsPRStore:
|
||||
session.commit()
|
||||
|
||||
def increment_process_attempts(self, repo_id: str, pr_number: int) -> bool:
|
||||
"""
|
||||
Increment the process attempts counter for a PR.
|
||||
"""Increment the process attempts counter for a PR.
|
||||
|
||||
Args:
|
||||
repo_id: Repository identifier
|
||||
@@ -75,8 +72,7 @@ class OpenhandsPRStore:
|
||||
num_openhands_review_comments: int,
|
||||
num_openhands_general_comments: int,
|
||||
) -> bool:
|
||||
"""
|
||||
Update OpenHands statistics for a PR with row-level locking and timestamp validation.
|
||||
"""Update OpenHands statistics for a PR with row-level locking and timestamp validation.
|
||||
|
||||
Args:
|
||||
repo_id: Repository identifier
|
||||
@@ -126,8 +122,7 @@ class OpenhandsPRStore:
|
||||
def get_unprocessed_prs(
|
||||
self, limit: int = 50, max_retries: int = 3
|
||||
) -> list[OpenhandsPR]:
|
||||
"""
|
||||
Get unprocessed PR entries from the OpenhandsPR table.
|
||||
"""Get unprocessed PR entries from the OpenhandsPR table.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of PRs to retrieve (default: 50)
|
||||
|
||||
@@ -34,8 +34,7 @@ class ProactiveConversationStore:
|
||||
pr_number: int,
|
||||
get_all_workflows: Callable,
|
||||
) -> WorkflowRunGroup | None:
|
||||
"""
|
||||
1. Get the workflow based on repo_id, pr_number, commit
|
||||
"""1. Get the workflow based on repo_id, pr_number, commit
|
||||
2. If the field doesn't exist
|
||||
- Fetch the workflow statuses and store them
|
||||
- Create a new record
|
||||
@@ -45,7 +44,6 @@ class ProactiveConversationStore:
|
||||
This method uses an explicit transaction with row-level locking to ensure
|
||||
thread safety when multiple processes access the same database rows.
|
||||
"""
|
||||
|
||||
should_send = False
|
||||
provider_repo_id = self.get_repo_id(provider, repo_id)
|
||||
|
||||
@@ -131,14 +129,12 @@ class ProactiveConversationStore:
|
||||
return final_workflow_group
|
||||
|
||||
async def clean_old_convos(self, older_than_minutes=30):
|
||||
"""
|
||||
Clean up proactive conversation records that are older than the specified time.
|
||||
"""Clean up proactive conversation records that are older than the specified time.
|
||||
|
||||
Args:
|
||||
older_than_minutes: Number of minutes. Records older than this will be deleted.
|
||||
Defaults to 30 minutes.
|
||||
"""
|
||||
|
||||
# Calculate the cutoff time (current time - older_than_minutes)
|
||||
cutoff_time = datetime.now(UTC) - timedelta(minutes=older_than_minutes)
|
||||
|
||||
|
||||
@@ -15,8 +15,7 @@ class RepositoryStore:
|
||||
config: OpenHandsConfig
|
||||
|
||||
def store_projects(self, repositories: list[StoredRepository]) -> None:
|
||||
"""
|
||||
Store repositories in database
|
||||
"""Store repositories in database
|
||||
|
||||
1. Make sure to store repositories if its ID doesn't exist
|
||||
2. If repository ID already exists, make sure to only update the repo is_public and repo_name fields
|
||||
|
||||
@@ -14,8 +14,7 @@ class SaasConversationValidator(ConversationValidator):
|
||||
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
|
||||
|
||||
async def _validate_api_key(self, api_key: str) -> str | None:
|
||||
"""
|
||||
Validate an API key and return the user_id and github_user_id if valid.
|
||||
"""Validate an API key and return the user_id and github_user_id if valid.
|
||||
|
||||
Args:
|
||||
api_key: The API key to validate
|
||||
@@ -49,8 +48,7 @@ class SaasConversationValidator(ConversationValidator):
|
||||
async def _validate_conversation_access(
|
||||
self, conversation_id: str, user_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that the user has access to the conversation.
|
||||
"""Validate that the user has access to the conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation
|
||||
@@ -81,8 +79,7 @@ class SaasConversationValidator(ConversationValidator):
|
||||
cookies_str: str,
|
||||
authorization_header: str | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Validate the conversation access using either an API key from the Authorization header
|
||||
"""Validate the conversation access using either an API key from the Authorization header
|
||||
or a keycloak_auth cookie.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -14,8 +14,7 @@ class SlackConversationStore:
|
||||
async def get_slack_conversation(
|
||||
self, channel_id: str, parent_id: str
|
||||
) -> SlackConversation | None:
|
||||
"""
|
||||
Get a slack conversation by channel_id and message_ts.
|
||||
"""Get a slack conversation by channel_id and message_ts.
|
||||
Both parameters are required to match for a conversation to be returned.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
|
||||
@@ -10,9 +10,7 @@ class SlackTeamStore:
|
||||
session_maker: sessionmaker
|
||||
|
||||
def get_team_bot_token(self, team_id: str) -> str | None:
|
||||
"""
|
||||
Get a team's bot access token by team_id
|
||||
"""
|
||||
"""Get a team's bot access token by team_id"""
|
||||
with session_maker() as session:
|
||||
team = session.query(SlackTeam).filter(SlackTeam.team_id == team_id).first()
|
||||
return team.bot_access_token if team else None
|
||||
@@ -22,9 +20,7 @@ class SlackTeamStore:
|
||||
team_id: str,
|
||||
bot_access_token: str,
|
||||
) -> SlackTeam:
|
||||
"""
|
||||
Create a new SlackTeam
|
||||
"""
|
||||
"""Create a new SlackTeam"""
|
||||
slack_team = SlackTeam(team_id=team_id, bot_access_token=bot_access_token)
|
||||
with session_maker() as session:
|
||||
session.query(SlackTeam).filter(SlackTeam.team_id == team_id).delete()
|
||||
|
||||
@@ -3,9 +3,7 @@ from storage.base import Base
|
||||
|
||||
|
||||
class StoredRepository(Base): # type: ignore
|
||||
"""
|
||||
Represents a repositories fetched from git providers.
|
||||
"""
|
||||
"""Represents a repositories fetched from git providers."""
|
||||
|
||||
__tablename__ = 'repos'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
@@ -5,9 +5,7 @@ from storage.base import Base
|
||||
|
||||
|
||||
class StoredSettings(Base): # type: ignore
|
||||
"""
|
||||
Legacy user settings storage. This should be considered deprecated - use UserSettings isntead
|
||||
"""
|
||||
"""Legacy user settings storage. This should be considered deprecated - use UserSettings isntead"""
|
||||
|
||||
__tablename__ = 'settings'
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
|
||||
@@ -3,8 +3,7 @@ from storage.base import Base
|
||||
|
||||
|
||||
class StripeCustomer(Base): # type: ignore
|
||||
"""
|
||||
Represents a stripe customer. We can't simply use the stripe API for this because:
|
||||
"""Represents a stripe customer. We can't simply use the stripe API for this because:
|
||||
"Don’t use search in read-after-write flows where strict consistency is necessary.
|
||||
Under normal operating conditions, data is searchable in less than a minute.
|
||||
Occasionally, propagation of new or updated data can be up to an hour behind during outages"
|
||||
|
||||
@@ -5,8 +5,7 @@ from storage.base import Base
|
||||
|
||||
|
||||
class SubscriptionAccess(Base): # type: ignore
|
||||
"""
|
||||
Represents a user's subscription access record.
|
||||
"""Represents a user's subscription access record.
|
||||
Tracks subscription status, duration, payment information, and cancellation status.
|
||||
"""
|
||||
|
||||
|
||||
@@ -3,9 +3,7 @@ from storage.base import Base
|
||||
|
||||
|
||||
class UserRepositoryMap(Base):
|
||||
"""
|
||||
Represents a map between user id and repo ids
|
||||
"""
|
||||
"""Represents a map between user id and repo ids"""
|
||||
|
||||
__tablename__ = 'user-repos'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
|
||||
@@ -16,8 +16,7 @@ class UserRepositoryMapStore:
|
||||
config: OpenHandsConfig
|
||||
|
||||
def store_user_repo_mappings(self, mappings: list[UserRepositoryMap]) -> None:
|
||||
"""
|
||||
Store user-repository mappings in database
|
||||
"""Store user-repository mappings in database
|
||||
|
||||
1. Make sure to store mappings if they don't exist
|
||||
2. If a mapping already exists (same user_id and repo_id), update the admin field
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Common Room Sync
|
||||
"""Common Room Sync
|
||||
|
||||
This script queries the database to count conversations created by each user,
|
||||
then creates or updates a signal in Common Room for each user with their
|
||||
|
||||
@@ -14,8 +14,7 @@ data_collector = GitHubDataCollector()
|
||||
|
||||
|
||||
def get_unprocessed_prs() -> list[OpenhandsPR]:
|
||||
"""
|
||||
Get unprocessed PR entries from the OpenhandsPR table.
|
||||
"""Get unprocessed PR entries from the OpenhandsPR table.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of PRs to retrieve (default: 50)
|
||||
@@ -29,19 +28,14 @@ def get_unprocessed_prs() -> list[OpenhandsPR]:
|
||||
|
||||
|
||||
async def process_pr(pr: OpenhandsPR):
|
||||
"""
|
||||
Process a single PR to enrich its data.
|
||||
"""
|
||||
|
||||
"""Process a single PR to enrich its data."""
|
||||
logger.info(f'Processing PR #{pr.pr_number} from repo {pr.repo_name}')
|
||||
await data_collector.save_full_pr(pr)
|
||||
store.increment_process_attempts(pr.repo_id, pr.pr_number)
|
||||
|
||||
|
||||
async def main():
|
||||
"""
|
||||
Main function to retrieve and process unprocessed PRs.
|
||||
"""
|
||||
"""Main function to retrieve and process unprocessed PRs."""
|
||||
logger.info('Starting PR data enrichment process')
|
||||
|
||||
# Get unprocessed PRs
|
||||
|
||||
@@ -49,9 +49,7 @@ class VerifyWebhookStatus:
|
||||
webhook_store: GitlabWebhookStore,
|
||||
webhook: GitlabWebhook,
|
||||
):
|
||||
"""
|
||||
Check if the GitLab resource still exists
|
||||
"""
|
||||
"""Check if the GitLab resource still exists"""
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
|
||||
@@ -83,9 +81,7 @@ class VerifyWebhookStatus:
|
||||
webhook_store: GitlabWebhookStore,
|
||||
webhook: GitlabWebhook,
|
||||
):
|
||||
"""
|
||||
Check is user still has permission to resource
|
||||
"""
|
||||
"""Check is user still has permission to resource"""
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
|
||||
@@ -120,9 +116,7 @@ class VerifyWebhookStatus:
|
||||
webhook_store: GitlabWebhookStore,
|
||||
webhook: GitlabWebhook,
|
||||
):
|
||||
"""
|
||||
Check whether webhook already exists on resource
|
||||
"""
|
||||
"""Check whether webhook already exists on resource"""
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
|
||||
@@ -192,9 +186,7 @@ class VerifyWebhookStatus:
|
||||
webhook_store: GitlabWebhookStore,
|
||||
webhook: GitlabWebhook,
|
||||
):
|
||||
"""
|
||||
Install webhook on resource
|
||||
"""
|
||||
"""Install webhook on resource"""
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
|
||||
@@ -241,8 +233,7 @@ class VerifyWebhookStatus:
|
||||
)
|
||||
|
||||
async def install_webhooks(self):
|
||||
"""
|
||||
Periodically check the conditions for installing a webhook on resource as valid
|
||||
"""Periodically check the conditions for installing a webhook on resource as valid
|
||||
Rows with valid conditions with contain (webhook_exists=False, status=WebhookStatus.VERIFIED)
|
||||
|
||||
Conditions we check for
|
||||
@@ -255,7 +246,6 @@ class VerifyWebhookStatus:
|
||||
- resource was never setup with webhook
|
||||
|
||||
"""
|
||||
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
# Get an instance of the webhook store
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for Common Room conversation count sync.
|
||||
"""Test script for Common Room conversation count sync.
|
||||
|
||||
This script tests the functionality of the Common Room sync script
|
||||
without making any API calls to Common Room or database connections.
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Shared fixtures for Jira integration tests.
|
||||
"""
|
||||
"""Shared fixtures for Jira integration tests."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Unit tests for JiraManager.
|
||||
"""
|
||||
"""Unit tests for JiraManager."""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Tests for Jira view classes and factory.
|
||||
"""
|
||||
"""Tests for Jira view classes and factory."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Shared fixtures for Jira DC integration tests.
|
||||
"""
|
||||
"""Shared fixtures for Jira DC integration tests."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Unit tests for JiraDcManager.
|
||||
"""
|
||||
"""Unit tests for JiraDcManager."""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Tests for Jira DC view classes and factory.
|
||||
"""
|
||||
"""Tests for Jira DC view classes and factory."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Shared fixtures for Linear integration tests.
|
||||
"""
|
||||
"""Shared fixtures for Linear integration tests."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Unit tests for LinearManager.
|
||||
"""
|
||||
"""Unit tests for LinearManager."""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Tests for Linear view classes and factory.
|
||||
"""
|
||||
"""Tests for Linear view classes and factory."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Mock implementation of the stripe_service module for testing.
|
||||
"""
|
||||
"""Mock implementation of the stripe_service module for testing."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Tests for conversation_callback_utils.py
|
||||
"""
|
||||
"""Tests for conversation_callback_utils.py"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Shared fixtures for all tests.
|
||||
"""
|
||||
"""Shared fixtures for all tests."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
@@ -44,8 +42,7 @@ def feature_embedding() -> FeatureEmbedding:
|
||||
|
||||
@pytest.fixture
|
||||
def featurizer(mock_llm, features) -> Featurizer:
|
||||
"""
|
||||
Create a featurizer for testing.
|
||||
"""Create a featurizer for testing.
|
||||
|
||||
Mocks out any calls to LLM.completion
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
Unit tests for data loading functionality in solvability/data.
|
||||
"""
|
||||
"""Unit tests for data loading functionality in solvability/data."""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
|
||||
@@ -38,7 +38,6 @@ def mock_response():
|
||||
|
||||
def test_set_response_cookie(mock_response, mock_request):
|
||||
"""Test setting the auth cookie on a response."""
|
||||
|
||||
with patch('server.routes.auth.config') as mock_config:
|
||||
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
This test file verifies that the billing routes correctly use the stripe_service
|
||||
"""This test file verifies that the billing routes correctly use the stripe_service
|
||||
functions with the new database-first approach.
|
||||
"""
|
||||
|
||||
@@ -71,7 +70,6 @@ async def test_create_customer_setup_session_uses_customer_id():
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_checkout_session_uses_customer_id():
|
||||
"""Test that create_checkout_session uses a customer ID string"""
|
||||
|
||||
# Create a mock request
|
||||
mock_request = MagicMock()
|
||||
mock_request.state = {'user_id': 'test-user-id'}
|
||||
@@ -157,7 +155,6 @@ async def test_create_checkout_session_uses_customer_id():
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_payment_method_uses_customer_id():
|
||||
"""Test that has_payment_method uses a customer ID string"""
|
||||
|
||||
# Create a mock request
|
||||
mock_request = MagicMock()
|
||||
mock_request.state = {'user_id': 'test-user-id'}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user