mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
1 Commits
debug-logg
...
refactor/a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a578ef2d22 |
@@ -8,7 +8,7 @@ services:
|
||||
container_name: openhands-app-${DATE:-}
|
||||
environment:
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/agent-server}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-31536c8-python}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-0fdea73-python}
|
||||
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -26,14 +26,12 @@ from integrations.utils import (
|
||||
from integrations.v1_utils import get_saas_user_auth
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import SecretStr
|
||||
from server.auth.auth_error import ExpiredError
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.integrations.service_types import AuthenticationError
|
||||
from openhands.server.types import (
|
||||
LLMAuthenticationError,
|
||||
MissingSettingsError,
|
||||
@@ -349,7 +347,7 @@ class GithubManager(Manager):
|
||||
|
||||
msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except (AuthenticationError, ExpiredError, SessionExpiredError) as e:
|
||||
except SessionExpiredError as e:
|
||||
logger.warning(
|
||||
f'[GitHub] Session expired for user {user_info.username}: {str(e)}'
|
||||
)
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Add git_user_name and git_user_email columns to user table.
|
||||
|
||||
Revision ID: 090
|
||||
Revises: 089
|
||||
Create Date: 2025-01-22
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '090'
|
||||
down_revision = '089'
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'user',
|
||||
sa.Column('git_user_name', sa.String, nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
'user',
|
||||
sa.Column('git_user_email', sa.String, nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('user', 'git_user_email')
|
||||
op.drop_column('user', 'git_user_name')
|
||||
24
enterprise/poetry.lock
generated
24
enterprise/poetry.lock
generated
@@ -6102,14 +6102,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.9.1"
|
||||
version = "1.9.0"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.9.1-py3-none-any.whl", hash = "sha256:ea1457760505b9ebfe6aabea08dedd010ce93aeb93edb450f00e25a0d056a723"},
|
||||
{file = "openhands_agent_server-1.9.1.tar.gz", hash = "sha256:d92a29a9d5aa94207519a5f8daad7c0a3d6641d5cba9f763f25aa4e85713fa0f"},
|
||||
{file = "openhands_agent_server-1.9.0-py3-none-any.whl", hash = "sha256:44b65fac5bb831541eb2e8726afb2682bde4816b4c6c90be9ad3cafd3dbcf971"},
|
||||
{file = "openhands_agent_server-1.9.0.tar.gz", hash = "sha256:ac41a948acf64ed661a9f383c293c305176f92bd12e6fc6362f5414cb7874ee1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6168,9 +6168,9 @@ memory-profiler = ">=0.61"
|
||||
numpy = "*"
|
||||
openai = "2.8"
|
||||
openhands-aci = "0.3.2"
|
||||
openhands-agent-server = "1.9.1"
|
||||
openhands-sdk = "1.9.1"
|
||||
openhands-tools = "1.9.1"
|
||||
openhands-agent-server = "1.9"
|
||||
openhands-sdk = "1.9"
|
||||
openhands-tools = "1.9"
|
||||
opentelemetry-api = ">=1.33.1"
|
||||
opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
|
||||
pathspec = ">=0.12.1"
|
||||
@@ -6225,14 +6225,14 @@ url = ".."
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.9.1"
|
||||
version = "1.9.0"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_sdk-1.9.1-py3-none-any.whl", hash = "sha256:0e732dfe0d91289536ea0410db9554d5a5b0326f60e547ea7a9d8ddab5fe93e4"},
|
||||
{file = "openhands_sdk-1.9.1.tar.gz", hash = "sha256:c6ba33f85efa4c2ec63eb1040cbe82839662bcbcf323654ed071a9ad38ce7994"},
|
||||
{file = "openhands_sdk-1.9.0-py3-none-any.whl", hash = "sha256:b427d8b9e587a5360c7d61742c290601998557e9b38b1c9e11a297659812c00d"},
|
||||
{file = "openhands_sdk-1.9.0.tar.gz", hash = "sha256:70048888fd4fbe44a86c35c402bbb99d30cf0cba50579ee1a8e3f43e05154150"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6253,14 +6253,14 @@ boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.9.1"
|
||||
version = "1.9.0"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_tools-1.9.1-py3-none-any.whl", hash = "sha256:411819657e00ffac5d5b1ba9adc6eb65a0a17cbefb5e3e1a34bb132ff61c59f2"},
|
||||
{file = "openhands_tools-1.9.1.tar.gz", hash = "sha256:331608994cce22b662038a2fed0bf7d2c1bb8dc27b1fc0a12a646e9bd76e0843"},
|
||||
{file = "openhands_tools-1.9.0-py3-none-any.whl", hash = "sha256:8becde0e913a31babb41eb93a8c10bf41d87ca1febd07bc958839c3583655305"},
|
||||
{file = "openhands_tools-1.9.0.tar.gz", hash = "sha256:d45f5f5210cb2bbcd8ab5f3a32051db1a532d0ec07cd306105f95cde42cf67f2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
||||
@@ -16,7 +16,6 @@ from keycloak.exceptions import (
|
||||
KeycloakError,
|
||||
KeycloakPostError,
|
||||
)
|
||||
from server.auth.auth_error import ExpiredError
|
||||
from server.auth.constants import (
|
||||
BITBUCKET_APP_CLIENT_ID,
|
||||
BITBUCKET_APP_CLIENT_SECRET,
|
||||
@@ -427,8 +426,6 @@ class TokenManager:
|
||||
access_token = data.get('access_token')
|
||||
refresh_token = data.get('refresh_token')
|
||||
if not access_token or not refresh_token:
|
||||
if data.get('error') == 'bad_refresh_token':
|
||||
raise ExpiredError()
|
||||
raise ValueError(
|
||||
'Failed to refresh token: missing access_token or refresh_token in response.'
|
||||
)
|
||||
|
||||
@@ -144,7 +144,7 @@ class SetAuthCookieMiddleware:
|
||||
# "if accepted_tos is not None" as there should not be any users with
|
||||
# accepted_tos equal to "None"
|
||||
if accepted_tos is False and request.url.path != '/api/accept_tos':
|
||||
logger.warning('User has not accepted the terms of service')
|
||||
logger.error('User has not accepted the terms of service')
|
||||
raise TosNotAcceptedError
|
||||
|
||||
def _should_attach(self, request: Request) -> bool:
|
||||
|
||||
@@ -13,33 +13,46 @@ from server.constants import (
|
||||
STRIPE_API_KEY,
|
||||
)
|
||||
from server.logger import logger
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.app_server.config import get_global_config
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
billing_router = APIRouter(prefix='/api/billing')
|
||||
|
||||
|
||||
async def validate_billing_enabled() -> None:
|
||||
# TODO: Add a new app_mode named "ON_PREM" to support self-hosted customers instead of doing this
|
||||
# and members should comment out the "validate_saas_environment" function if they are developing and testing locally.
|
||||
def is_all_hands_saas_environment(request: Request) -> bool:
|
||||
"""Check if the current domain is an All Hands SaaS environment.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
|
||||
"""
|
||||
Validate that the billing feature flag is enabled
|
||||
hostname = request.url.hostname or ''
|
||||
return hostname.endswith('all-hands.dev') or hostname.endswith('openhands.dev')
|
||||
|
||||
|
||||
def validate_saas_environment(request: Request) -> None:
|
||||
"""Validate that the request is coming from an All Hands SaaS environment.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Raises:
|
||||
HTTPException: If the request is not from an All Hands SaaS environment
|
||||
"""
|
||||
config = get_global_config()
|
||||
web_client_config = await config.web_client.get_web_client_config()
|
||||
if not web_client_config.feature_flags.enable_billing:
|
||||
if not is_all_hands_saas_environment(request):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=(
|
||||
'Billing is disabled in this environment. '
|
||||
'Please set OH_WEB_CLIENT_FEATURE_FLAGS_ENABLE_BILLING to enable billing.'
|
||||
),
|
||||
detail='Checkout sessions are only available for All Hands SaaS environments',
|
||||
)
|
||||
|
||||
|
||||
@@ -141,15 +154,14 @@ async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
|
||||
async def create_customer_setup_session(
|
||||
request: Request, user_id: str = Depends(get_user_id)
|
||||
) -> CreateBillingSessionResponse:
|
||||
await validate_billing_enabled()
|
||||
validate_saas_environment(request)
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
base_url = _get_base_url(request)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_info['customer_id'],
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{base_url}?free_credits=success',
|
||||
cancel_url=f'{base_url}',
|
||||
success_url=f'{request.base_url}?free_credits=success',
|
||||
cancel_url=f'{request.base_url}',
|
||||
)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
@@ -161,8 +173,8 @@ async def create_checkout_session(
|
||||
request: Request,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
await validate_billing_enabled()
|
||||
base_url = _get_base_url(request)
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_info['customer_id'],
|
||||
@@ -185,8 +197,8 @@ async def create_checkout_session(
|
||||
saved_payment_method_options={
|
||||
'payment_method_save': 'enabled',
|
||||
},
|
||||
success_url=f'{base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
)
|
||||
logger.info(
|
||||
'created_stripe_checkout_session',
|
||||
@@ -277,7 +289,7 @@ async def success_callback(session_id: str, request: Request):
|
||||
session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
|
||||
f'{request.base_url}settings/billing?checkout=success', status_code=302
|
||||
)
|
||||
|
||||
|
||||
@@ -305,13 +317,5 @@ async def cancel_callback(session_id: str, request: Request):
|
||||
session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302
|
||||
f'{request.base_url}settings/billing?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
|
||||
def _get_base_url(request: Request) -> URL:
|
||||
# Never send any part of the credit card process over a non secure connection
|
||||
base_url = request.base_url
|
||||
if base_url.hostname != 'localhost':
|
||||
base_url = base_url.replace(scheme='https')
|
||||
return base_url
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import zlib
|
||||
from base64 import b64decode, b64encode
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
@@ -52,11 +51,7 @@ def add_github_proxy_routes(app: FastAPI):
|
||||
state_payload = json.dumps(
|
||||
[query_params['state'][0], query_params['redirect_uri'][0]]
|
||||
)
|
||||
# Compress before encrypting to reduce URL length
|
||||
# This is critical for feature deployments where reCAPTCHA tokens in state
|
||||
# can cause "URL too long" errors from GitHub
|
||||
compressed_payload = zlib.compress(state_payload.encode())
|
||||
state = b64encode(_fernet().encrypt(compressed_payload)).decode()
|
||||
state = b64encode(_fernet().encrypt(state_payload.encode())).decode()
|
||||
query_params['state'] = [state]
|
||||
query_params['redirect_uri'] = [
|
||||
f'https://{request.url.netloc}/github-proxy/callback'
|
||||
@@ -72,9 +67,7 @@ def add_github_proxy_routes(app: FastAPI):
|
||||
parsed_url = urlparse(str(request.url))
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
state = query_params['state'][0]
|
||||
# Decrypt and decompress (reverse of github_proxy_start)
|
||||
decrypted_payload = _fernet().decrypt(b64decode(state.encode()))
|
||||
decrypted_state = zlib.decompress(decrypted_payload).decode()
|
||||
decrypted_state = _fernet().decrypt(b64decode(state.encode())).decode()
|
||||
|
||||
# Build query Params
|
||||
state, redirect_uri = json.loads(decrypted_state)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from storage.org import Org
|
||||
|
||||
|
||||
class OrgCreationError(Exception):
|
||||
@@ -28,27 +27,6 @@ class OrgDatabaseError(OrgCreationError):
|
||||
pass
|
||||
|
||||
|
||||
class OrgDeletionError(Exception):
|
||||
"""Base exception for organization deletion errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OrgAuthorizationError(OrgDeletionError):
|
||||
"""Raised when user is not authorized to delete organization."""
|
||||
|
||||
def __init__(self, message: str = 'Not authorized to delete organization'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OrgNotFoundError(Exception):
|
||||
"""Raised when organization is not found or user doesn't have access."""
|
||||
|
||||
def __init__(self, org_id: str):
|
||||
self.org_id = org_id
|
||||
super().__init__(f'Organization with id "{org_id}" not found')
|
||||
|
||||
|
||||
class OrgCreate(BaseModel):
|
||||
"""Request model for creating a new organization."""
|
||||
|
||||
@@ -87,85 +65,3 @@ class OrgResponse(BaseModel):
|
||||
enable_solvability_analysis: bool | None = None
|
||||
v1_enabled: bool | None = None
|
||||
credits: float | None = None
|
||||
|
||||
@classmethod
|
||||
def from_org(cls, org: Org, credits: float | None = None) -> 'OrgResponse':
|
||||
"""Create an OrgResponse from an Org entity.
|
||||
|
||||
Args:
|
||||
org: The organization entity to convert
|
||||
credits: Optional credits value (defaults to None)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The response model instance
|
||||
"""
|
||||
return cls(
|
||||
id=str(org.id),
|
||||
name=org.name,
|
||||
contact_name=org.contact_name,
|
||||
contact_email=org.contact_email,
|
||||
conversation_expiration=org.conversation_expiration,
|
||||
agent=org.agent,
|
||||
default_max_iterations=org.default_max_iterations,
|
||||
security_analyzer=org.security_analyzer,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
default_llm_model=org.default_llm_model,
|
||||
default_llm_api_key_for_byor=None,
|
||||
default_llm_base_url=org.default_llm_base_url,
|
||||
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
|
||||
enable_default_condenser=org.enable_default_condenser
|
||||
if org.enable_default_condenser is not None
|
||||
else True,
|
||||
billing_margin=org.billing_margin,
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters
|
||||
if org.enable_proactive_conversation_starters is not None
|
||||
else True,
|
||||
sandbox_base_container_image=org.sandbox_base_container_image,
|
||||
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
|
||||
org_version=org.org_version if org.org_version is not None else 0,
|
||||
mcp_config=org.mcp_config,
|
||||
search_api_key=None,
|
||||
sandbox_api_key=None,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
credits=credits,
|
||||
)
|
||||
|
||||
|
||||
class OrgPage(BaseModel):
|
||||
"""Paginated response model for organization list."""
|
||||
|
||||
items: list[OrgResponse]
|
||||
next_page_id: str | None = None
|
||||
|
||||
|
||||
class OrgUpdate(BaseModel):
|
||||
"""Request model for updating an organization."""
|
||||
|
||||
# Basic organization information (any authenticated user can update)
|
||||
contact_name: str | None = None
|
||||
contact_email: EmailStr | None = Field(default=None, strip_whitespace=True)
|
||||
conversation_expiration: int | None = None
|
||||
default_max_iterations: int | None = Field(default=None, gt=0)
|
||||
remote_runtime_resource_factor: int | None = Field(default=None, gt=0)
|
||||
billing_margin: float | None = Field(default=None, ge=0, le=1)
|
||||
enable_proactive_conversation_starters: bool | None = None
|
||||
sandbox_base_container_image: str | None = None
|
||||
sandbox_runtime_container_image: str | None = None
|
||||
mcp_config: dict | None = None
|
||||
sandbox_api_key: str | None = None
|
||||
max_budget_per_task: float | None = Field(default=None, gt=0)
|
||||
enable_solvability_analysis: bool | None = None
|
||||
v1_enabled: bool | None = None
|
||||
|
||||
# LLM settings (require admin/owner role)
|
||||
default_llm_model: str | None = None
|
||||
default_llm_api_key_for_byor: str | None = None
|
||||
default_llm_base_url: str | None = None
|
||||
search_api_key: str | None = None
|
||||
security_analyzer: str | None = None
|
||||
agent: str | None = None
|
||||
confirmation_mode: bool | None = None
|
||||
enable_default_condenser: bool | None = None
|
||||
condenser_max_size: int | None = Field(default=None, ge=20)
|
||||
|
||||
@@ -1,98 +1,20 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
LiteLLMIntegrationError,
|
||||
OrgAuthorizationError,
|
||||
OrgCreate,
|
||||
OrgDatabaseError,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
OrgPage,
|
||||
OrgResponse,
|
||||
OrgUpdate,
|
||||
)
|
||||
from storage.org_service import OrgService
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
# Initialize API router
|
||||
org_router = APIRouter(prefix='/api/organizations')
|
||||
|
||||
|
||||
@org_router.get('', response_model=OrgPage)
|
||||
async def list_user_orgs(
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='The max number of results in the page', gt=0, lte=100),
|
||||
] = 100,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgPage:
|
||||
"""List organizations for the authenticated user.
|
||||
|
||||
This endpoint returns a paginated list of all organizations that the
|
||||
authenticated user is a member of.
|
||||
|
||||
Args:
|
||||
page_id: Optional page ID (offset) for pagination
|
||||
limit: Maximum number of organizations to return (1-100, default 100)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgPage: Paginated list of organizations
|
||||
|
||||
Raises:
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
'Listing organizations for user',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'page_id': page_id,
|
||||
'limit': limit,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch organizations from service layer
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
user_id=user_id,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Convert Org entities to OrgResponse objects
|
||||
org_responses = [OrgResponse.from_org(org, credits=None) for org in orgs]
|
||||
|
||||
logger.info(
|
||||
'Successfully retrieved organizations',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_count': len(org_responses),
|
||||
'has_more': next_page_id is not None,
|
||||
},
|
||||
)
|
||||
|
||||
return OrgPage(items=org_responses, next_page_id=next_page_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error listing organizations',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve organizations',
|
||||
)
|
||||
|
||||
|
||||
@org_router.post('', response_model=OrgResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_org(
|
||||
org_data: OrgCreate,
|
||||
@@ -136,7 +58,31 @@ async def create_org(
|
||||
# Retrieve credits from LiteLLM
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits)
|
||||
return OrgResponse(
|
||||
id=str(org.id),
|
||||
name=org.name,
|
||||
contact_name=org.contact_name,
|
||||
contact_email=org.contact_email,
|
||||
conversation_expiration=org.conversation_expiration,
|
||||
agent=org.agent,
|
||||
default_max_iterations=org.default_max_iterations,
|
||||
security_analyzer=org.security_analyzer,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
default_llm_model=org.default_llm_model,
|
||||
default_llm_base_url=org.default_llm_base_url,
|
||||
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
|
||||
enable_default_condenser=org.enable_default_condenser,
|
||||
billing_margin=org.billing_margin,
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters,
|
||||
sandbox_base_container_image=org.sandbox_base_container_image,
|
||||
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
|
||||
org_version=org.org_version,
|
||||
mcp_config=org.mcp_config,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
credits=credits,
|
||||
)
|
||||
except OrgNameExistsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
@@ -169,234 +115,3 @@ async def create_org(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
|
||||
async def get_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgResponse:
|
||||
"""Get organization details by ID.
|
||||
|
||||
This endpoint allows authenticated users who are members of an organization
|
||||
to retrieve its details. Only members of the organization can access this endpoint.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
|
||||
HTTPException: 404 if organization not found or user is not a member
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
'Retrieving organization details',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to get organization with membership validation
|
||||
org = await OrgService.get_org_by_id(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Retrieve credits from LiteLLM
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits)
|
||||
except OrgNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error retrieving organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.delete('/{org_id}', status_code=status.HTTP_200_OK)
|
||||
async def delete_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
) -> dict:
|
||||
"""Delete an organization.
|
||||
|
||||
This endpoint allows authenticated organization owners to delete their organization.
|
||||
All associated data including organization members, conversations, billing data,
|
||||
and external LiteLLM team resources will be permanently removed.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to delete
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
dict: Confirmation message with deleted organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if user is not the organization owner
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 500 if deletion fails
|
||||
"""
|
||||
logger.info(
|
||||
'Organization deletion requested',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to delete organization with cleanup
|
||||
deleted_org = await OrgService.delete_org_with_cleanup(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Organization deletion completed successfully',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'org_name': deleted_org.name,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
'message': 'Organization deleted successfully',
|
||||
'organization': {
|
||||
'id': str(deleted_org.id),
|
||||
'name': deleted_org.name,
|
||||
'contact_name': deleted_org.contact_name,
|
||||
'contact_email': deleted_org.contact_email,
|
||||
},
|
||||
}
|
||||
|
||||
except OrgNotFoundError as e:
|
||||
logger.warning(
|
||||
'Organization not found for deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgAuthorizationError as e:
|
||||
logger.warning(
|
||||
'User not authorized to delete organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database error during organization deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to delete organization',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error during organization deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.patch('/{org_id}', response_model=OrgResponse)
|
||||
async def update_org(
|
||||
org_id: UUID,
|
||||
update_data: OrgUpdate,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgResponse:
|
||||
"""Update an existing organization.
|
||||
|
||||
This endpoint allows authenticated users to update organization settings.
|
||||
LLM-related settings require admin or owner role in the organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to update (UUID validated by FastAPI)
|
||||
update_data: Organization update data
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The updated organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if org_id is invalid UUID format (handled by FastAPI)
|
||||
HTTPException: 403 if user lacks permission for LLM settings
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 422 if validation errors occur (handled by FastAPI)
|
||||
HTTPException: 500 if update fails
|
||||
"""
|
||||
logger.info(
|
||||
'Updating organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to update organization with permission checks
|
||||
updated_org = await OrgService.update_org_with_permissions(
|
||||
org_id=org_id,
|
||||
update_data=update_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Retrieve credits from LiteLLM (following same pattern as create endpoint)
|
||||
credits = await OrgService.get_org_credits(user_id, updated_org.id)
|
||||
|
||||
return OrgResponse.from_org(updated_org, credits=credits)
|
||||
|
||||
except ValueError as e:
|
||||
# Organization not found
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except PermissionError as e:
|
||||
# User lacks permission for LLM settings
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database operation failed',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update organization',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error updating organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
@@ -26,7 +26,6 @@ from server.sharing.shared_conversation_models import (
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
StoredConversationMetadata,
|
||||
@@ -58,7 +57,7 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
include_sub_conversations: bool = False,
|
||||
) -> SharedConversationPage:
|
||||
"""Search for shared conversations."""
|
||||
query = self._public_select_with_saas_metadata()
|
||||
query = self._public_select()
|
||||
|
||||
# Conditionally exclude sub-conversations based on the parameter
|
||||
if not include_sub_conversations:
|
||||
@@ -105,17 +104,14 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
rows = result.scalars().all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(rows) > limit
|
||||
if has_more:
|
||||
rows = rows[:limit]
|
||||
|
||||
items = [
|
||||
self._to_shared_conversation(stored, saas_metadata=saas_metadata)
|
||||
for stored, saas_metadata in rows
|
||||
]
|
||||
items = [self._to_shared_conversation(row) for row in rows]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
@@ -156,18 +152,17 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
self, conversation_id: UUID
|
||||
) -> SharedConversation | None:
|
||||
"""Get a single public conversation info, returning None if missing or not shared."""
|
||||
query = self._public_select_with_saas_metadata().where(
|
||||
query = self._public_select().where(
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
row = result.first()
|
||||
stored = result.scalar_one_or_none()
|
||||
|
||||
if row is None:
|
||||
if stored is None:
|
||||
return None
|
||||
|
||||
stored, saas_metadata = row
|
||||
return self._to_shared_conversation(stored, saas_metadata=saas_metadata)
|
||||
return self._to_shared_conversation(stored)
|
||||
|
||||
def _public_select(self):
|
||||
"""Create a select query that only returns public conversations."""
|
||||
@@ -178,25 +173,6 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||
return query
|
||||
|
||||
def _public_select_with_saas_metadata(self):
|
||||
"""Create a select query that returns public conversations with SAAS metadata.
|
||||
|
||||
This joins with conversation_metadata_saas to retrieve the user_id needed
|
||||
for constructing the correct event storage path. Uses LEFT OUTER JOIN to
|
||||
support conversations that may not have SAAS metadata (e.g., in tests).
|
||||
"""
|
||||
query = (
|
||||
select(StoredConversationMetadata, StoredConversationMetadataSaas)
|
||||
.outerjoin(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
.where(StoredConversationMetadata.public == True) # noqa: E712
|
||||
)
|
||||
return query
|
||||
|
||||
def _apply_filters(
|
||||
self,
|
||||
query,
|
||||
@@ -235,16 +211,9 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
def _to_shared_conversation(
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
saas_metadata: StoredConversationMetadataSaas | None = None,
|
||||
sub_conversation_ids: list[UUID] | None = None,
|
||||
) -> SharedConversation:
|
||||
"""Convert StoredConversationMetadata to SharedConversation.
|
||||
|
||||
Args:
|
||||
stored: The base conversation metadata from conversation_metadata table.
|
||||
saas_metadata: Optional SAAS metadata containing user_id and org_id.
|
||||
sub_conversation_ids: Optional list of sub-conversation IDs.
|
||||
"""
|
||||
"""Convert StoredConversationMetadata to SharedConversation."""
|
||||
# V1 conversations should always have a sandbox_id
|
||||
sandbox_id = stored.sandbox_id
|
||||
assert sandbox_id is not None
|
||||
@@ -270,16 +239,9 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
|
||||
created_at = self._fix_timezone(stored.created_at)
|
||||
updated_at = self._fix_timezone(stored.last_updated_at)
|
||||
|
||||
# Get user_id from SAAS metadata if available
|
||||
created_by_user_id = (
|
||||
str(saas_metadata.user_id)
|
||||
if saas_metadata and saas_metadata.user_id
|
||||
else None
|
||||
)
|
||||
|
||||
return SharedConversation(
|
||||
id=UUID(stored.conversation_id),
|
||||
created_by_user_id=created_by_user_id,
|
||||
created_by_user_id=None, # user_id is no longer stored in conversation metadata
|
||||
sandbox_id=stored.sandbox_id,
|
||||
selected_repository=stored.selected_repository,
|
||||
selected_branch=stored.selected_branch,
|
||||
|
||||
@@ -96,7 +96,7 @@ class LiteLlmManager:
|
||||
user_settings: UserSettings,
|
||||
) -> UserSettings | None:
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:start',
|
||||
'SettingsStore:umigrate_lite_llm_entries:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
@@ -141,35 +141,19 @@ class LiteLlmManager:
|
||||
return None
|
||||
credits = max(max_budget - spend, 0.0)
|
||||
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:create_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:update_user',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_user(
|
||||
client, keycloak_user_id, max_budget=UNLIMITED_BUDGET_SETTING
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:add_user_to_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key:
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:update_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
@@ -178,10 +162,6 @@ class LiteLlmManager:
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key_for_byor:
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:update_byor_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
@@ -189,10 +169,6 @@ class LiteLlmManager:
|
||||
team_id=org_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:end',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -9,11 +9,8 @@ from uuid import UUID as parse_uuid
|
||||
from server.constants import ORG_SETTINGS_VERSION, get_default_litellm_model
|
||||
from server.routes.org_models import (
|
||||
LiteLLMIntegrationError,
|
||||
OrgAuthorizationError,
|
||||
OrgDatabaseError,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
OrgUpdate,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
@@ -396,224 +393,6 @@ class OrgService:
|
||||
)
|
||||
return e
|
||||
|
||||
@staticmethod
|
||||
def has_admin_or_owner_role(user_id: str, org_id: UUID) -> bool:
|
||||
"""
|
||||
Check if user has admin or owner role in the specified organization.
|
||||
|
||||
Args:
|
||||
user_id: User ID to check
|
||||
org_id: Organization ID to check membership in
|
||||
|
||||
Returns:
|
||||
bool: True if user has admin or owner role, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Parse user_id as UUID for database query
|
||||
user_uuid = parse_uuid(user_id)
|
||||
|
||||
# Get the user's membership in this organization
|
||||
# Note: The type annotation says int but the actual column is UUID
|
||||
org_member = OrgMemberStore.get_org_member(org_id, user_uuid)
|
||||
if not org_member:
|
||||
return False
|
||||
|
||||
# Get the role details
|
||||
role = RoleStore.get_role_by_id(org_member.role_id)
|
||||
if not role:
|
||||
return False
|
||||
|
||||
# Admin and owner roles have elevated permissions
|
||||
# Based on test files, both admin and owner have rank 1
|
||||
return role.name in ['admin', 'owner']
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'Error checking user role in organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_org_member(user_id: str, org_id: UUID) -> bool:
|
||||
"""
|
||||
Check if user is a member of the specified organization.
|
||||
|
||||
Args:
|
||||
user_id: User ID to check
|
||||
org_id: Organization ID to check membership in
|
||||
|
||||
Returns:
|
||||
bool: True if user is a member, False otherwise
|
||||
"""
|
||||
try:
|
||||
user_uuid = parse_uuid(user_id)
|
||||
org_member = OrgMemberStore.get_org_member(org_id, user_uuid)
|
||||
return org_member is not None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'Error checking user membership in organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _get_llm_settings_fields() -> set[str]:
|
||||
"""
|
||||
Get the set of organization fields that are considered LLM settings
|
||||
and require admin/owner role to update.
|
||||
|
||||
Returns:
|
||||
set[str]: Set of field names that require elevated permissions
|
||||
"""
|
||||
return {
|
||||
'default_llm_model',
|
||||
'default_llm_api_key_for_byor',
|
||||
'default_llm_base_url',
|
||||
'search_api_key',
|
||||
'security_analyzer',
|
||||
'agent',
|
||||
'confirmation_mode',
|
||||
'enable_default_condenser',
|
||||
'condenser_max_size',
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _has_llm_settings_updates(update_data: OrgUpdate) -> set[str]:
|
||||
"""
|
||||
Check if the update contains any LLM settings fields.
|
||||
|
||||
Args:
|
||||
update_data: The organization update data
|
||||
|
||||
Returns:
|
||||
set[str]: Set of LLM fields being updated (empty if none)
|
||||
"""
|
||||
llm_fields = OrgService._get_llm_settings_fields()
|
||||
update_dict = update_data.model_dump(exclude_none=True)
|
||||
return llm_fields.intersection(update_dict.keys())
|
||||
|
||||
@staticmethod
|
||||
async def update_org_with_permissions(
|
||||
org_id: UUID,
|
||||
update_data: OrgUpdate,
|
||||
user_id: str,
|
||||
) -> Org:
|
||||
"""
|
||||
Update organization with permission checks for LLM settings.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID to update
|
||||
update_data: Organization update data from request
|
||||
user_id: ID of the user requesting the update
|
||||
|
||||
Returns:
|
||||
Org: The updated organization object
|
||||
|
||||
Raises:
|
||||
ValueError: If organization not found
|
||||
PermissionError: If user is not a member, or lacks admin/owner role for LLM settings
|
||||
OrgDatabaseError: If database update fails
|
||||
"""
|
||||
logger.info(
|
||||
'Updating organization with permission checks',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'has_update_data': update_data is not None,
|
||||
},
|
||||
)
|
||||
|
||||
# Validate organization exists
|
||||
existing_org = OrgStore.get_org_by_id(org_id)
|
||||
if not existing_org:
|
||||
raise ValueError(f'Organization with ID {org_id} not found')
|
||||
|
||||
# Check if user is a member of this organization
|
||||
if not OrgService.is_org_member(user_id, org_id):
|
||||
logger.warning(
|
||||
'Non-member attempted to update organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
raise PermissionError(
|
||||
'User must be a member of the organization to update it'
|
||||
)
|
||||
|
||||
# Check if update contains any LLM settings
|
||||
llm_fields_being_updated = OrgService._has_llm_settings_updates(update_data)
|
||||
if llm_fields_being_updated:
|
||||
# Verify user has admin or owner role
|
||||
has_permission = OrgService.has_admin_or_owner_role(user_id, org_id)
|
||||
if not has_permission:
|
||||
logger.warning(
|
||||
'User attempted to update LLM settings without permission',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'attempted_fields': list(llm_fields_being_updated),
|
||||
},
|
||||
)
|
||||
raise PermissionError(
|
||||
'Admin or owner role required to update LLM settings'
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'User has permission to update LLM settings',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'llm_fields': list(llm_fields_being_updated),
|
||||
},
|
||||
)
|
||||
|
||||
# Convert to dict for OrgStore (excluding None values)
|
||||
update_dict = update_data.model_dump(exclude_none=True)
|
||||
if not update_dict:
|
||||
logger.info(
|
||||
'No fields to update',
|
||||
extra={'org_id': str(org_id), 'user_id': user_id},
|
||||
)
|
||||
return existing_org
|
||||
|
||||
# Perform the update
|
||||
try:
|
||||
updated_org = OrgStore.update_org(org_id, update_dict)
|
||||
if not updated_org:
|
||||
raise OrgDatabaseError('Failed to update organization in database')
|
||||
|
||||
logger.info(
|
||||
'Organization updated successfully',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'updated_fields': list(update_dict.keys()),
|
||||
},
|
||||
)
|
||||
|
||||
return updated_org
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to update organization',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise OrgDatabaseError(f'Failed to update organization: {str(e)}')
|
||||
|
||||
@staticmethod
|
||||
async def get_org_credits(user_id: str, org_id: UUID) -> float | None:
|
||||
"""
|
||||
@@ -662,183 +441,3 @@ class OrgService:
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs_paginated(
|
||||
user_id: str, page_id: str | None = None, limit: int = 100
|
||||
):
|
||||
"""
|
||||
Get paginated list of organizations for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
page_id: Optional page ID (offset as string) for pagination
|
||||
limit: Maximum number of organizations to return
|
||||
|
||||
Returns:
|
||||
Tuple of (list of Org objects, next_page_id or None)
|
||||
"""
|
||||
logger.debug(
|
||||
'Fetching paginated organizations for user',
|
||||
extra={'user_id': user_id, 'page_id': page_id, 'limit': limit},
|
||||
)
|
||||
|
||||
# Convert user_id string to UUID
|
||||
user_uuid = parse_uuid(user_id)
|
||||
|
||||
# Fetch organizations from store
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_uuid, page_id=page_id, limit=limit
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'Retrieved organizations for user',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_count': len(orgs),
|
||||
'has_more': next_page_id is not None,
|
||||
},
|
||||
)
|
||||
|
||||
return orgs, next_page_id
|
||||
|
||||
@staticmethod
|
||||
async def get_org_by_id(org_id: UUID, user_id: str) -> Org:
|
||||
"""
|
||||
Get organization by ID with membership validation.
|
||||
|
||||
This method verifies that the user is a member of the organization
|
||||
before returning the organization details.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
|
||||
Returns:
|
||||
Org: The organization object
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If organization not found or user is not a member
|
||||
"""
|
||||
logger.info(
|
||||
'Retrieving organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# Verify user is a member of the organization
|
||||
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
|
||||
if not org_member:
|
||||
logger.warning(
|
||||
'User is not a member of organization or organization does not exist',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
# Retrieve organization
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
logger.error(
|
||||
'Organization not found despite valid membership',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
logger.info(
|
||||
'Successfully retrieved organization',
|
||||
extra={
|
||||
'org_id': str(org.id),
|
||||
'org_name': org.name,
|
||||
'user_id': user_id,
|
||||
},
|
||||
)
|
||||
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def verify_owner_authorization(user_id: str, org_id: UUID) -> None:
|
||||
"""
|
||||
Verify that the user is the owner of the organization.
|
||||
|
||||
Args:
|
||||
user_id: User ID to check
|
||||
org_id: Organization ID
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If organization doesn't exist
|
||||
OrgAuthorizationError: If user is not authorized to delete
|
||||
"""
|
||||
# Check if organization exists
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
# Check if user is a member of the organization
|
||||
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
|
||||
if not org_member:
|
||||
raise OrgAuthorizationError('User is not a member of this organization')
|
||||
|
||||
# Check if user has owner role
|
||||
role = RoleStore.get_role_by_id(org_member.role_id)
|
||||
if not role or role.name != 'owner':
|
||||
raise OrgAuthorizationError(
|
||||
'Only organization owners can delete organizations'
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'User authorization verified for organization deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'role': role.name},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def delete_org_with_cleanup(user_id: str, org_id: UUID) -> Org:
|
||||
"""
|
||||
Delete organization with complete cleanup of all associated data.
|
||||
|
||||
This method performs the complete organization deletion workflow:
|
||||
1. Verifies user authorization (owner only)
|
||||
2. Performs database cascade deletion and LiteLLM cleanup in single transaction
|
||||
|
||||
Args:
|
||||
user_id: User ID requesting deletion (must be owner)
|
||||
org_id: Organization ID to delete
|
||||
|
||||
Returns:
|
||||
Org: The deleted organization details
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If organization doesn't exist
|
||||
OrgAuthorizationError: If user is not authorized to delete
|
||||
OrgDatabaseError: If database operations or LiteLLM cleanup fail
|
||||
"""
|
||||
logger.info(
|
||||
'Starting organization deletion',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# Step 1: Verify user authorization
|
||||
OrgService.verify_owner_authorization(user_id, org_id)
|
||||
|
||||
# Step 2: Perform database cascade deletion with LiteLLM cleanup in transaction
|
||||
try:
|
||||
deleted_org = await OrgStore.delete_org_cascade(org_id)
|
||||
if not deleted_org:
|
||||
# This shouldn't happen since we verified existence above
|
||||
raise OrgDatabaseError('Organization not found during deletion')
|
||||
|
||||
logger.info(
|
||||
'Organization deletion completed successfully',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'org_name': deleted_org.name,
|
||||
},
|
||||
)
|
||||
|
||||
return deleted_org
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Organization deletion failed',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise OrgDatabaseError(f'Failed to delete organization: {str(e)}')
|
||||
|
||||
@@ -10,10 +10,8 @@ from server.constants import (
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
@@ -98,63 +96,6 @@ class OrgStore:
|
||||
orgs = session.query(Org).all()
|
||||
return orgs
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs_paginated(
|
||||
user_id: UUID, page_id: str | None = None, limit: int = 100
|
||||
) -> tuple[list[Org], str | None]:
|
||||
"""
|
||||
Get paginated list of organizations for a user.
|
||||
|
||||
Args:
|
||||
user_id: User UUID
|
||||
page_id: Optional page ID (offset as string) for pagination
|
||||
limit: Maximum number of organizations to return
|
||||
|
||||
Returns:
|
||||
Tuple of (list of Org objects, next_page_id or None)
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# Build query joining OrgMember with Org
|
||||
query = (
|
||||
session.query(Org)
|
||||
.join(OrgMember, Org.id == OrgMember.org_id)
|
||||
.filter(OrgMember.user_id == user_id)
|
||||
.order_by(Org.name)
|
||||
)
|
||||
|
||||
# Apply pagination offset
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
query = query.offset(offset)
|
||||
except ValueError:
|
||||
# If page_id is not a valid integer, start from beginning
|
||||
offset = 0
|
||||
else:
|
||||
offset = 0
|
||||
|
||||
# Fetch limit + 1 to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
orgs = query.all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(orgs) > limit
|
||||
if has_more:
|
||||
orgs = orgs[:limit]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
# Validate org versions
|
||||
validated_orgs = [
|
||||
OrgStore._validate_org_version(org) for org in orgs if org
|
||||
]
|
||||
validated_orgs = [org for org in validated_orgs if org is not None]
|
||||
|
||||
return validated_orgs, next_page_id
|
||||
|
||||
@staticmethod
|
||||
def update_org(
|
||||
org_id: UUID,
|
||||
@@ -245,119 +186,3 @@ class OrgStore:
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
async def delete_org_cascade(org_id: UUID) -> Org | None:
|
||||
"""
|
||||
Delete organization and all associated data in cascade, including external LiteLLM cleanup.
|
||||
|
||||
Args:
|
||||
org_id: UUID of the organization to delete
|
||||
|
||||
Returns:
|
||||
Org: The deleted organization object, or None if not found
|
||||
|
||||
Raises:
|
||||
Exception: If database operations or LiteLLM cleanup fail
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# First get the organization to return it
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 1. Delete conversation data for organization conversations
|
||||
session.execute(
|
||||
text("""
|
||||
DELETE FROM conversation_metadata
|
||||
WHERE conversation_id IN (
|
||||
SELECT conversation_id FROM conversation_metadata_saas WHERE org_id = :org_id
|
||||
)
|
||||
"""),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
session.execute(
|
||||
text("""
|
||||
DELETE FROM app_conversation_start_task
|
||||
WHERE app_conversation_id::text IN (
|
||||
SELECT conversation_id FROM conversation_metadata_saas WHERE org_id = :org_id
|
||||
)
|
||||
"""),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 2. Delete organization-owned data tables (direct org_id foreign keys)
|
||||
session.execute(
|
||||
text('DELETE FROM billing_sessions WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text(
|
||||
'DELETE FROM conversation_metadata_saas WHERE org_id = :org_id'
|
||||
),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM custom_secrets WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM api_keys WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM slack_conversation WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM slack_users WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
session.execute(
|
||||
text('DELETE FROM stripe_customers WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 3. Delete organization memberships
|
||||
session.execute(
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 4. Handle users with this as current_org_id
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE "user" SET current_org_id = NULL WHERE current_org_id = :org_id'
|
||||
),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 5. Finally delete the organization
|
||||
session.delete(org)
|
||||
|
||||
# 6. Clean up LiteLLM team before committing transaction
|
||||
logger.info(
|
||||
'Deleting LiteLLM team within database transaction',
|
||||
extra={'org_id': str(org_id)},
|
||||
)
|
||||
await LiteLlmManager.delete_team(str(org_id))
|
||||
|
||||
# 7. Commit all changes only if everything succeeded
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'Successfully deleted organization and all associated data including LiteLLM team',
|
||||
extra={'org_id': str(org_id), 'org_name': org.name},
|
||||
)
|
||||
|
||||
return org
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(
|
||||
'Failed to delete organization - transaction rolled back',
|
||||
extra={'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -4,9 +4,7 @@ Store class for managing roles.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.role import Role
|
||||
|
||||
|
||||
@@ -35,20 +33,6 @@ class RoleStore:
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.name == name).first()
|
||||
|
||||
@staticmethod
|
||||
async def get_role_by_name_async(
|
||||
name: str,
|
||||
session: Optional[AsyncSession] = None,
|
||||
) -> Optional[Role]:
|
||||
"""Get role by name."""
|
||||
if session is not None:
|
||||
result = await session.execute(select(Role).where(Role.name == name))
|
||||
return result.scalars().first()
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Role).where(Role.name == name))
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def list_roles() -> List[Role]:
|
||||
"""List all roles."""
|
||||
|
||||
@@ -31,8 +31,6 @@ class User(Base): # type: ignore
|
||||
user_consents_to_analytics = Column(Boolean, nullable=True)
|
||||
email = Column(String, nullable=True)
|
||||
email_verified = Column(Boolean, nullable=True)
|
||||
git_user_name = Column(String, nullable=True)
|
||||
git_user_email = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
role = relationship('Role', back_populates='users')
|
||||
|
||||
@@ -14,9 +14,9 @@ from server.constants import (
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.database import session_maker
|
||||
from storage.encrypt_utils import decrypt_legacy_model
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
@@ -116,7 +116,7 @@ class UserStore:
|
||||
redis_client = UserStore._get_redis_client()
|
||||
if redis_client is None:
|
||||
logger.warning(
|
||||
'user_store:_acquire_user_creation_lock:no_redis_client',
|
||||
'saas_settings_store:_acquire_user_creation_lock:no_redis_client',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True # Proceed without locking if Redis is unavailable
|
||||
@@ -159,20 +159,12 @@ class UserStore:
|
||||
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
logger.info(
|
||||
'user_store:migrate_user:calling_litellm_migrate_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await LiteLlmManager.migrate_entries(
|
||||
str(org.id),
|
||||
user_id,
|
||||
decrypted_user_settings,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'user_store:migrate_user:done_litellm_migrate_entries',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
custom_settings = UserStore._has_custom_settings(
|
||||
decrypted_user_settings, user_settings.user_version
|
||||
)
|
||||
@@ -180,15 +172,7 @@ class UserStore:
|
||||
# avoids circular reference. This migrate method is temprorary until all users are migrated.
|
||||
from integrations.stripe_service import migrate_customer
|
||||
|
||||
logger.info(
|
||||
'user_store:migrate_user:calling_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await migrate_customer(session, user_id, org)
|
||||
logger.info(
|
||||
'user_store:migrate_user:done_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
@@ -217,15 +201,7 @@ class UserStore:
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
logger.info(
|
||||
'user_store:migrate_user:calling_get_role_by_name',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
role = await RoleStore.get_role_by_name_async('owner')
|
||||
logger.info(
|
||||
'user_store:migrate_user:done_get_role_by_name',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
role = RoleStore.get_role_by_name('owner')
|
||||
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
@@ -253,10 +229,6 @@ class UserStore:
|
||||
user_settings.already_migrated = True
|
||||
session.merge(user_settings)
|
||||
session.flush()
|
||||
logger.info(
|
||||
'user_store:migrate_user:session_flush_complete',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# need to migrate conversation metadata
|
||||
session.execute(
|
||||
@@ -324,10 +296,6 @@ class UserStore:
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
logger.info(
|
||||
'user_store:migrate_user:session_committed',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
@@ -354,7 +322,7 @@ class UserStore:
|
||||
):
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'user_store:create_default_settings:waiting_for_lock',
|
||||
'saas_settings_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
call_async_from_sync(
|
||||
@@ -404,13 +372,13 @@ class UserStore:
|
||||
This is the preferred method when calling from an async context as it
|
||||
avoids event loop conflicts that can occur with the sync version.
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if user:
|
||||
return user
|
||||
|
||||
@@ -418,39 +386,32 @@ class UserStore:
|
||||
while not await UserStore._acquire_user_creation_lock(user_id):
|
||||
# The user is already being created in another thread / process
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:waiting_for_lock',
|
||||
'saas_settings_store:create_default_settings:waiting_for_lock',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
|
||||
|
||||
# Check for user again as migration could have happened while trying to get the lock.
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if user:
|
||||
return user
|
||||
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:start_migration',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
result = await session.execute(
|
||||
select(UserSettings).filter(
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
user_settings = result.scalars().first()
|
||||
if user_settings:
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id)
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:calling_migrate_user',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
user = await UserStore.migrate_user(
|
||||
user_id,
|
||||
user_settings,
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import SecretStr
|
||||
from server.routes.github_proxy import add_github_proxy_routes
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_github_proxy(monkeypatch):
|
||||
"""Create a FastAPI app with github proxy routes enabled."""
|
||||
# Enable the github proxy endpoints
|
||||
monkeypatch.setenv('GITHUB_PROXY_ENDPOINTS', '1')
|
||||
|
||||
# Mock the config to have a jwt_secret
|
||||
mock_config = type(
|
||||
'MockConfig', (), {'jwt_secret': SecretStr('test-secret-key-for-testing')}
|
||||
)()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
with patch('server.routes.github_proxy.GITHUB_PROXY_ENDPOINTS', True):
|
||||
with patch('server.routes.github_proxy.config', mock_config):
|
||||
add_github_proxy_routes(app)
|
||||
|
||||
# Return app and mock_config so we can use the same config in tests
|
||||
return app, mock_config
|
||||
|
||||
|
||||
def test_state_compress_encrypt_and_decrypt_decompress_roundtrip(
|
||||
app_with_github_proxy, monkeypatch
|
||||
):
|
||||
"""
|
||||
Verify the code path used by github_proxy_start -> github_proxy_callback:
|
||||
- compress payload, encrypt, base64-encode (what the start code does)
|
||||
- base64-decode, decrypt, decompress (what the callback code does)
|
||||
|
||||
This test exercises the actual endpoints to verify the roundtrip works correctly.
|
||||
"""
|
||||
app, mock_config = app_with_github_proxy
|
||||
client = TestClient(app)
|
||||
|
||||
original_state = 'some-state-value'
|
||||
original_redirect_uri = 'https://example.com/redirect'
|
||||
|
||||
# Call github_proxy_start endpoint - it should redirect to GitHub with encrypted state
|
||||
with patch('server.routes.github_proxy.config', mock_config):
|
||||
response = client.get(
|
||||
'/github-proxy/test-subdomain/login/oauth/authorize',
|
||||
params={
|
||||
'state': original_state,
|
||||
'redirect_uri': original_redirect_uri,
|
||||
'client_id': 'test-client-id',
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 307
|
||||
redirect_url = response.headers['location']
|
||||
|
||||
# Verify it redirects to GitHub
|
||||
assert redirect_url.startswith('https://github.com/login/oauth/authorize')
|
||||
|
||||
# Parse the redirect URL to get the encrypted state
|
||||
parsed = urlparse(redirect_url)
|
||||
query_params = parse_qs(parsed.query)
|
||||
encrypted_state = query_params['state'][0]
|
||||
|
||||
# The redirect_uri should now point to our callback
|
||||
assert 'github-proxy/callback' in query_params['redirect_uri'][0]
|
||||
|
||||
# Now simulate GitHub calling back with this encrypted state
|
||||
with patch('server.routes.github_proxy.config', mock_config):
|
||||
callback_response = client.get(
|
||||
'/github-proxy/callback',
|
||||
params={
|
||||
'state': encrypted_state,
|
||||
'code': 'test-auth-code',
|
||||
},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert callback_response.status_code == 307
|
||||
final_redirect = callback_response.headers['location']
|
||||
|
||||
# Verify the callback redirects to the original redirect_uri
|
||||
assert final_redirect.startswith(original_redirect_uri)
|
||||
|
||||
# Parse the final redirect to verify the state was decrypted correctly
|
||||
final_parsed = urlparse(final_redirect)
|
||||
final_params = parse_qs(final_parsed.query)
|
||||
|
||||
assert final_params['state'][0] == original_state
|
||||
assert final_params['code'][0] == 'test-auth-code'
|
||||
File diff suppressed because it is too large
Load Diff
@@ -163,7 +163,7 @@ async def test_create_checkout_session_stripe_error(
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('server.routes.billing.validate_billing_enabled'),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
await create_checkout_session(
|
||||
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
|
||||
@@ -204,7 +204,7 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('server.routes.billing.validate_billing_enabled'),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
@@ -236,8 +236,8 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
|
||||
mode='payment',
|
||||
payment_method_types=['card'],
|
||||
saved_payment_method_options={'payment_method_save': 'enabled'},
|
||||
success_url='https://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
|
||||
cancel_url='https://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
|
||||
success_url='http://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
|
||||
cancel_url='http://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
|
||||
)
|
||||
|
||||
# Verify database session creation
|
||||
@@ -331,7 +331,7 @@ async def test_success_callback_success():
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location']
|
||||
== 'https://test.com/settings/billing?checkout=success'
|
||||
== 'http://test.com/settings/billing?checkout=success'
|
||||
)
|
||||
|
||||
# Verify LiteLLM API calls
|
||||
@@ -402,7 +402,7 @@ async def test_cancel_callback_session_not_found():
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location']
|
||||
== 'https://test.com/settings/billing?checkout=cancel'
|
||||
== 'http://test.com/settings/billing?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify no database updates occurred
|
||||
@@ -429,7 +429,7 @@ async def test_cancel_callback_success():
|
||||
assert response.status_code == 302
|
||||
assert (
|
||||
response.headers['location']
|
||||
== 'https://test.com/settings/billing?checkout=cancel'
|
||||
== 'http://test.com/settings/billing?checkout=cancel'
|
||||
)
|
||||
|
||||
# Verify database updates
|
||||
@@ -490,7 +490,7 @@ async def test_create_customer_setup_session_success():
|
||||
AsyncMock(return_value=mock_customer_info),
|
||||
),
|
||||
patch('stripe.checkout.Session.create_async', mock_create),
|
||||
patch('server.routes.billing.validate_billing_enabled'),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
result = await create_customer_setup_session(mock_request, 'mock_user')
|
||||
|
||||
@@ -502,6 +502,6 @@ async def test_create_customer_setup_session_success():
|
||||
customer='mock-customer-id',
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url='https://test.com/?free_credits=success',
|
||||
cancel_url='https://test.com/',
|
||||
success_url='http://test.com/?free_credits=success',
|
||||
cancel_url='http://test.com/',
|
||||
)
|
||||
|
||||
@@ -68,84 +68,3 @@ def test_user_model(session_maker):
|
||||
)
|
||||
assert queried_org_member is not None
|
||||
assert queried_org_member.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
|
||||
|
||||
def test_user_model_git_user_fields(session_maker):
|
||||
"""Test that git_user_name and git_user_email columns exist and work correctly."""
|
||||
with session_maker() as session:
|
||||
# Arrange
|
||||
org = Org(name='test_org_git')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
test_user_id = uuid4()
|
||||
|
||||
# Act
|
||||
user = User(
|
||||
id=test_user_id,
|
||||
current_org_id=org.id,
|
||||
git_user_name='Test Git Author',
|
||||
git_user_email='git@example.com',
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
# Assert
|
||||
queried_user = session.query(User).filter(User.id == test_user_id).first()
|
||||
assert queried_user.git_user_name == 'Test Git Author'
|
||||
assert queried_user.git_user_email == 'git@example.com'
|
||||
|
||||
|
||||
def test_user_model_git_user_fields_nullable(session_maker):
|
||||
"""Test that git_user_name and git_user_email can be null."""
|
||||
with session_maker() as session:
|
||||
# Arrange
|
||||
org = Org(name='test_org_nullable')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
test_user_id = uuid4()
|
||||
|
||||
# Act - create user without git fields
|
||||
user = User(
|
||||
id=test_user_id,
|
||||
current_org_id=org.id,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
# Assert
|
||||
queried_user = session.query(User).filter(User.id == test_user_id).first()
|
||||
assert queried_user.git_user_name is None
|
||||
assert queried_user.git_user_email is None
|
||||
|
||||
|
||||
def test_user_model_git_user_fields_in_table_columns():
|
||||
"""Test that git_user_name and git_user_email are in User table columns."""
|
||||
# Arrange & Act
|
||||
column_names = [c.name for c in User.__table__.columns]
|
||||
|
||||
# Assert
|
||||
assert 'git_user_name' in column_names
|
||||
assert 'git_user_email' in column_names
|
||||
|
||||
|
||||
def test_user_model_git_user_fields_hasattr(session_maker):
|
||||
"""Test that hasattr returns True for git_user_* fields on User model.
|
||||
|
||||
This verifies the fix for SaasSettingsStore.store() which uses hasattr
|
||||
to determine if a field should be persisted to a model.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# Arrange
|
||||
org = Org(name='test_org_hasattr')
|
||||
session.add(org)
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid4(), current_org_id=org.id)
|
||||
session.add(user)
|
||||
session.flush()
|
||||
|
||||
# Assert - hasattr must return True for store() to work
|
||||
assert hasattr(user, 'git_user_name')
|
||||
assert hasattr(user, 'git_user_email')
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -415,374 +415,3 @@ def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm
|
||||
)
|
||||
assert persisted_member.max_iterations == 100
|
||||
assert persisted_member.llm_model == 'gpt-4'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_org_cascade_success(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: Valid organization with associated data
|
||||
WHEN: delete_org_cascade is called
|
||||
THEN: Organization and all associated data are deleted and org object is returned
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
# Create expected return object
|
||||
expected_org = Org(
|
||||
id=org_id,
|
||||
name='Test Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
)
|
||||
|
||||
# Mock delete_org_cascade to avoid database schema constraints
|
||||
async def mock_delete_org_cascade(org_id_param):
|
||||
# Verify the method was called with correct parameter
|
||||
assert org_id_param == org_id
|
||||
|
||||
# Return the organization object (simulating successful deletion)
|
||||
return expected_org
|
||||
|
||||
with patch(
|
||||
'storage.org_store.OrgStore.delete_org_cascade', mock_delete_org_cascade
|
||||
):
|
||||
# Act
|
||||
result = await OrgStore.delete_org_cascade(org_id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.id == org_id
|
||||
assert result.name == 'Test Organization'
|
||||
assert result.contact_name == 'John Doe'
|
||||
assert result.contact_email == 'john@example.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_org_cascade_not_found(session_maker):
|
||||
"""
|
||||
GIVEN: Organization ID that doesn't exist
|
||||
WHEN: delete_org_cascade is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
non_existent_id = uuid.uuid4()
|
||||
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
# Act
|
||||
result = await OrgStore.delete_org_cascade(non_existent_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_org_cascade_litellm_failure_causes_rollback(
|
||||
session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: Organization exists but LiteLLM cleanup fails
|
||||
WHEN: delete_org_cascade is called
|
||||
THEN: Transaction is rolled back and organization still exists
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
role = Role(id=1, name='owner', rank=1)
|
||||
user = User(id=user_id, current_org_id=org_id)
|
||||
org = Org(
|
||||
id=org_id,
|
||||
name='Test Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
)
|
||||
org_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=1,
|
||||
status='active',
|
||||
llm_api_key='test-key',
|
||||
)
|
||||
session.add_all([role, user, org, org_member])
|
||||
session.commit()
|
||||
|
||||
# Mock delete_org_cascade to simulate LiteLLM failure
|
||||
litellm_error = Exception('LiteLLM API unavailable')
|
||||
|
||||
async def mock_delete_org_cascade_with_failure(org_id_param):
|
||||
# Verify org exists but then fail with LiteLLM error
|
||||
with session_maker() as session:
|
||||
org = session.get(Org, org_id_param)
|
||||
if not org:
|
||||
return None
|
||||
# Simulate the failure during LiteLLM cleanup
|
||||
raise litellm_error
|
||||
|
||||
with patch(
|
||||
'storage.org_store.OrgStore.delete_org_cascade',
|
||||
mock_delete_org_cascade_with_failure,
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await OrgStore.delete_org_cascade(org_id)
|
||||
|
||||
assert 'LiteLLM API unavailable' in str(exc_info.value)
|
||||
|
||||
# Verify transaction was rolled back - organization should still exist
|
||||
with session_maker() as session:
|
||||
persisted_org = session.get(Org, org_id)
|
||||
assert persisted_org is not None
|
||||
assert persisted_org.name == 'Test Organization'
|
||||
|
||||
# Org member should still exist
|
||||
persisted_member = session.query(OrgMember).filter_by(org_id=org_id).first()
|
||||
assert persisted_member is not None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User is member of multiple organizations
|
||||
WHEN: get_user_orgs_paginated is called without page_id
|
||||
THEN: First page of organizations is returned in alphabetical order
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
other_user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
# Create orgs for the user
|
||||
org1 = Org(name='Alpha Org')
|
||||
org2 = Org(name='Beta Org')
|
||||
org3 = Org(name='Gamma Org')
|
||||
# Create org for another user (should not be included)
|
||||
org4 = Org(name='Other Org')
|
||||
session.add_all([org1, org2, org3, org4])
|
||||
session.flush()
|
||||
|
||||
# Create user and role
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
other_user = User(id=other_user_id, current_org_id=org4.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, other_user, role])
|
||||
session.flush()
|
||||
|
||||
# Create memberships
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
member2 = OrgMember(
|
||||
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
|
||||
)
|
||||
member3 = OrgMember(
|
||||
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
|
||||
)
|
||||
other_member = OrgMember(
|
||||
org_id=org4.id, user_id=other_user_id, role_id=1, llm_api_key='key4'
|
||||
)
|
||||
session.add_all([member1, member2, member3, other_member])
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=2
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 2
|
||||
assert orgs[0].name == 'Alpha Org'
|
||||
assert orgs[1].name == 'Beta Org'
|
||||
assert next_page_id == '2' # Has more results
|
||||
# Verify other user's org is not included
|
||||
org_names = [org.name for org in orgs]
|
||||
assert 'Other Org' not in org_names
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_with_page_id(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User has multiple organizations and page_id is provided
|
||||
WHEN: get_user_orgs_paginated is called with page_id
|
||||
THEN: Organizations starting from offset are returned
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
org1 = Org(name='Alpha Org')
|
||||
org2 = Org(name='Beta Org')
|
||||
org3 = Org(name='Gamma Org')
|
||||
session.add_all([org1, org2, org3])
|
||||
session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
member2 = OrgMember(
|
||||
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
|
||||
)
|
||||
member3 = OrgMember(
|
||||
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
|
||||
)
|
||||
session.add_all([member1, member2, member3])
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id='1', limit=1
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 1
|
||||
assert orgs[0].name == 'Beta Org' # Second org (offset 1)
|
||||
assert next_page_id == '2' # Has more results
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_no_more_results(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User has organizations but fewer than limit
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
THEN: All organizations are returned and next_page_id is None
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
org1 = Org(name='Alpha Org')
|
||||
org2 = Org(name='Beta Org')
|
||||
session.add_all([org1, org2])
|
||||
session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
member2 = OrgMember(
|
||||
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
|
||||
)
|
||||
session.add_all([member1, member2])
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 2
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_invalid_page_id(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: Invalid page_id (non-numeric string)
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
THEN: Results start from beginning (offset 0)
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
org1 = Org(name='Alpha Org')
|
||||
session.add(org1)
|
||||
session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
session.add(member1)
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id='invalid', limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 1
|
||||
assert orgs[0].name == 'Alpha Org'
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_empty_results(session_maker):
|
||||
"""
|
||||
GIVEN: User has no organizations
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
THEN: Empty list and None next_page_id are returned
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 0
|
||||
assert next_page_id is None
|
||||
|
||||
|
||||
def test_get_user_orgs_paginated_ordering(session_maker, mock_litellm_api):
|
||||
"""
|
||||
GIVEN: User has organizations with different names
|
||||
WHEN: get_user_orgs_paginated is called
|
||||
THEN: Organizations are returned in alphabetical order by name
|
||||
"""
|
||||
# Arrange
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
# Create orgs in non-alphabetical order
|
||||
org3 = Org(name='Zebra Org')
|
||||
org1 = Org(name='Apple Org')
|
||||
org2 = Org(name='Banana Org')
|
||||
session.add_all([org3, org1, org2])
|
||||
session.flush()
|
||||
|
||||
user = User(id=user_id, current_org_id=org1.id)
|
||||
role = Role(id=1, name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
member1 = OrgMember(
|
||||
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
|
||||
)
|
||||
member2 = OrgMember(
|
||||
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
|
||||
)
|
||||
member3 = OrgMember(
|
||||
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
|
||||
)
|
||||
session.add_all([member1, member2, member3])
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.org_store.session_maker', session_maker):
|
||||
orgs, _ = OrgStore.get_user_orgs_paginated(
|
||||
user_id=user_id, page_id=None, limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(orgs) == 3
|
||||
assert orgs[0].name == 'Apple Org'
|
||||
assert orgs[1].name == 'Banana Org'
|
||||
assert orgs[2].name == 'Zebra Org'
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from server.sharing.shared_conversation_models import (
|
||||
@@ -13,9 +13,6 @@ from server.sharing.sql_shared_conversation_info_service import (
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.org import Org
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user import User
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
@@ -431,261 +428,3 @@ class TestSharedConversationInfoService:
|
||||
page1_ids = {item.id for item in result.items}
|
||||
page2_ids = {item.id for item in result2.items}
|
||||
assert page1_ids.isdisjoint(page2_ids)
|
||||
|
||||
|
||||
class TestSharedConversationInfoServiceWithSaasMetadata:
|
||||
"""Test cases for SharedConversationInfoService with SAAS metadata.
|
||||
|
||||
These tests verify that created_by_user_id is correctly retrieved from
|
||||
the conversation_metadata_saas table when it exists.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine_with_saas(self):
|
||||
"""Create an async SQLite engine with all SAAS tables."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_with_saas(
|
||||
self, async_engine_with_saas
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session for testing with SAAS tables."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine_with_saas, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
yield db_session
|
||||
|
||||
@pytest.fixture
|
||||
async def test_org(self, async_session_with_saas) -> Org:
|
||||
"""Create a test organization."""
|
||||
org = Org(id=uuid4(), name=f'test_org_{uuid4().hex[:8]}')
|
||||
async_session_with_saas.add(org)
|
||||
await async_session_with_saas.commit()
|
||||
return org
|
||||
|
||||
@pytest.fixture
|
||||
async def test_user(self, async_session_with_saas, test_org) -> User:
|
||||
"""Create a test user belonging to the test organization."""
|
||||
user = User(id=uuid4(), current_org_id=test_org.id)
|
||||
async_session_with_saas.add(user)
|
||||
await async_session_with_saas.commit()
|
||||
return user
|
||||
|
||||
@pytest.fixture
|
||||
async def shared_service_with_saas(self, async_session_with_saas):
|
||||
"""Create a SharedConversationInfoService for testing."""
|
||||
return SQLSharedConversationInfoService(db_session=async_session_with_saas)
|
||||
|
||||
@pytest.fixture
|
||||
async def app_service_with_saas(self, async_session_with_saas):
|
||||
"""Create an AppConversationInfoService for creating test data."""
|
||||
return SQLAppConversationInfoService(
|
||||
db_session=async_session_with_saas,
|
||||
user_context=SpecifyUserContext(user_id=None),
|
||||
)
|
||||
|
||||
async def _create_saas_metadata(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
conversation_id: UUID,
|
||||
user_id: UUID,
|
||||
org_id: UUID,
|
||||
) -> StoredConversationMetadataSaas:
|
||||
"""Helper to create SAAS metadata for a conversation."""
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=str(conversation_id),
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
)
|
||||
db_session.add(saas_metadata)
|
||||
await db_session.commit()
|
||||
return saas_metadata
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_shared_conversation_returns_user_id_from_saas_metadata(
|
||||
self,
|
||||
shared_service_with_saas,
|
||||
app_service_with_saas,
|
||||
async_session_with_saas,
|
||||
test_user,
|
||||
test_org,
|
||||
):
|
||||
"""Test that get_shared_conversation_info returns created_by_user_id from SAAS metadata."""
|
||||
# Arrange
|
||||
conversation_id = uuid4()
|
||||
conversation = AppConversationInfo(
|
||||
id=conversation_id,
|
||||
created_by_user_id=None,
|
||||
sandbox_id='test_sandbox',
|
||||
title='Public Conversation With User',
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
await app_service_with_saas.save_app_conversation_info(conversation)
|
||||
await self._create_saas_metadata(
|
||||
async_session_with_saas, conversation_id, test_user.id, test_org.id
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await shared_service_with_saas.get_shared_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.created_by_user_id == str(test_user.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_shared_conversations_returns_user_id_from_saas_metadata(
|
||||
self,
|
||||
shared_service_with_saas,
|
||||
app_service_with_saas,
|
||||
async_session_with_saas,
|
||||
test_user,
|
||||
test_org,
|
||||
):
|
||||
"""Test that search_shared_conversation_info returns created_by_user_id from SAAS metadata."""
|
||||
# Arrange
|
||||
conversation_id = uuid4()
|
||||
conversation = AppConversationInfo(
|
||||
id=conversation_id,
|
||||
created_by_user_id=None,
|
||||
sandbox_id='test_sandbox_search',
|
||||
title='Searchable Public Conversation',
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
await app_service_with_saas.save_app_conversation_info(conversation)
|
||||
await self._create_saas_metadata(
|
||||
async_session_with_saas, conversation_id, test_user.id, test_org.id
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await shared_service_with_saas.search_shared_conversation_info()
|
||||
|
||||
# Assert
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].created_by_user_id == str(test_user.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_get_shared_conversations_returns_user_id_from_saas_metadata(
|
||||
self,
|
||||
shared_service_with_saas,
|
||||
app_service_with_saas,
|
||||
async_session_with_saas,
|
||||
test_user,
|
||||
test_org,
|
||||
):
|
||||
"""Test that batch_get_shared_conversation_info returns created_by_user_id from SAAS metadata."""
|
||||
# Arrange
|
||||
conversation_id = uuid4()
|
||||
conversation = AppConversationInfo(
|
||||
id=conversation_id,
|
||||
created_by_user_id=None,
|
||||
sandbox_id='test_sandbox_batch',
|
||||
title='Batch Get Conversation',
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
await app_service_with_saas.save_app_conversation_info(conversation)
|
||||
await self._create_saas_metadata(
|
||||
async_session_with_saas, conversation_id, test_user.id, test_org.id
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await shared_service_with_saas.batch_get_shared_conversation_info(
|
||||
[conversation_id]
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0] is not None
|
||||
assert result[0].created_by_user_id == str(test_user.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_conversations_with_and_without_saas_metadata(
|
||||
self,
|
||||
shared_service_with_saas,
|
||||
app_service_with_saas,
|
||||
async_session_with_saas,
|
||||
test_user,
|
||||
test_org,
|
||||
):
|
||||
"""Test handling of conversations where some have SAAS metadata and some don't."""
|
||||
# Arrange
|
||||
conv_with_saas_id = uuid4()
|
||||
conv_without_saas_id = uuid4()
|
||||
|
||||
conv_with_saas = AppConversationInfo(
|
||||
id=conv_with_saas_id,
|
||||
created_by_user_id=None,
|
||||
sandbox_id='sandbox_with_saas',
|
||||
title='With SAAS Metadata',
|
||||
created_at=datetime(2023, 1, 2, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
conv_without_saas = AppConversationInfo(
|
||||
id=conv_without_saas_id,
|
||||
created_by_user_id=None,
|
||||
sandbox_id='sandbox_without_saas',
|
||||
title='Without SAAS Metadata',
|
||||
created_at=datetime(2023, 1, 1, tzinfo=UTC),
|
||||
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
|
||||
public=True,
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
|
||||
await app_service_with_saas.save_app_conversation_info(conv_with_saas)
|
||||
await app_service_with_saas.save_app_conversation_info(conv_without_saas)
|
||||
await self._create_saas_metadata(
|
||||
async_session_with_saas, conv_with_saas_id, test_user.id, test_org.id
|
||||
)
|
||||
|
||||
# Act
|
||||
result = await shared_service_with_saas.search_shared_conversation_info(
|
||||
sort_order=SharedConversationSortOrder.CREATED_AT
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.items) == 2
|
||||
conv_without = next(
|
||||
item for item in result.items if item.id == conv_without_saas_id
|
||||
)
|
||||
conv_with = next(item for item in result.items if item.id == conv_with_saas_id)
|
||||
assert conv_without.created_by_user_id is None
|
||||
assert conv_with.created_by_user_id == str(test_user.id)
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
|
||||
import { screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { PlanPreview } from "#/components/features/chat/plan-preview";
|
||||
|
||||
// Mock the feature flag to always return true (not testing feature flag behavior)
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
USE_PLANNING_AGENT: vi.fn(() => true),
|
||||
}));
|
||||
|
||||
// Mock i18n - need to preserve initReactI18next and I18nextProvider for test-utils
|
||||
vi.mock("react-i18next", async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import("react-i18next")>();
|
||||
return {
|
||||
...actual,
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
describe("PlanPreview", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should render nothing when planContent is null", () => {
|
||||
renderWithProviders(<PlanPreview planContent={null} />);
|
||||
|
||||
const contentDiv = screen.getByTestId("plan-preview-content");
|
||||
expect(contentDiv).toBeInTheDocument();
|
||||
expect(contentDiv.textContent?.trim() || "").toBe("");
|
||||
});
|
||||
|
||||
it("should render nothing when planContent is undefined", () => {
|
||||
renderWithProviders(<PlanPreview planContent={undefined} />);
|
||||
|
||||
const contentDiv = screen.getByTestId("plan-preview-content");
|
||||
expect(contentDiv).toBeInTheDocument();
|
||||
expect(contentDiv.textContent?.trim() || "").toBe("");
|
||||
});
|
||||
|
||||
it("should render markdown content when planContent is provided", () => {
|
||||
const planContent = "# Plan Title\n\nThis is the plan content.";
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<PlanPreview planContent={planContent} />,
|
||||
);
|
||||
|
||||
// Check that component rendered and contains the content (markdown may break up text)
|
||||
expect(container.firstChild).not.toBeNull();
|
||||
expect(container.textContent).toContain("Plan Title");
|
||||
expect(container.textContent).toContain("This is the plan content.");
|
||||
});
|
||||
|
||||
it("should render full content when length is less than or equal to 300 characters", () => {
|
||||
const planContent = "A".repeat(300);
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<PlanPreview planContent={planContent} />,
|
||||
);
|
||||
|
||||
// Content should be present (may be broken up by markdown)
|
||||
expect(container.textContent).toContain(planContent);
|
||||
expect(screen.queryByText(/COMMON\$READ_MORE/i)).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should truncate content when length exceeds 300 characters", () => {
|
||||
const longContent = "A".repeat(350);
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<PlanPreview planContent={longContent} />,
|
||||
);
|
||||
|
||||
// Truncated content should be present (may be broken up by markdown)
|
||||
expect(container.textContent).toContain("A".repeat(300));
|
||||
expect(container.textContent).toContain("...");
|
||||
expect(container.textContent).toContain("COMMON$READ_MORE");
|
||||
});
|
||||
|
||||
it("should call onViewClick when View button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onViewClick = vi.fn();
|
||||
|
||||
renderWithProviders(
|
||||
<PlanPreview planContent="Plan content" onViewClick={onViewClick} />,
|
||||
);
|
||||
|
||||
const viewButton = screen.getByTestId("plan-preview-view-button");
|
||||
expect(viewButton).toBeInTheDocument();
|
||||
|
||||
await user.click(viewButton);
|
||||
|
||||
expect(onViewClick).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should call onViewClick when Read More button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onViewClick = vi.fn();
|
||||
const longContent = "A".repeat(350);
|
||||
|
||||
renderWithProviders(
|
||||
<PlanPreview planContent={longContent} onViewClick={onViewClick} />,
|
||||
);
|
||||
|
||||
const readMoreButton = screen.getByTestId("plan-preview-read-more-button");
|
||||
expect(readMoreButton).toBeInTheDocument();
|
||||
|
||||
await user.click(readMoreButton);
|
||||
|
||||
expect(onViewClick).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should call onBuildClick when Build button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onBuildClick = vi.fn();
|
||||
|
||||
renderWithProviders(
|
||||
<PlanPreview planContent="Plan content" onBuildClick={onBuildClick} />,
|
||||
);
|
||||
|
||||
const buildButton = screen.getByTestId("plan-preview-build-button");
|
||||
expect(buildButton).toBeInTheDocument();
|
||||
|
||||
await user.click(buildButton);
|
||||
|
||||
expect(onBuildClick).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should render header with PLAN_MD text", () => {
|
||||
const { container } = renderWithProviders(
|
||||
<PlanPreview planContent="Plan content" />,
|
||||
);
|
||||
|
||||
// Check that the translation key is rendered (i18n mock returns the key)
|
||||
expect(container.textContent).toContain("COMMON$PLAN_MD");
|
||||
});
|
||||
|
||||
it("should render plan content", () => {
|
||||
const planContent = `# Heading 1
|
||||
## Heading 2
|
||||
- List item 1
|
||||
- List item 2
|
||||
|
||||
**Bold text** and *italic text*`;
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<PlanPreview planContent={planContent} />,
|
||||
);
|
||||
|
||||
expect(container.textContent).toContain("Heading 1");
|
||||
expect(container.textContent).toContain("Heading 2");
|
||||
});
|
||||
});
|
||||
@@ -1,35 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { shouldRenderEvent } from "#/components/v1/chat/event-content-helpers/should-render-event";
|
||||
import {
|
||||
createPlanningFileEditorActionEvent,
|
||||
createOtherActionEvent,
|
||||
createPlanningObservationEvent,
|
||||
createUserMessageEvent,
|
||||
} from "test-utils";
|
||||
|
||||
describe("shouldRenderEvent - PlanningFileEditorAction", () => {
|
||||
it("should return false for PlanningFileEditorAction", () => {
|
||||
const event = createPlanningFileEditorActionEvent("action-1");
|
||||
|
||||
expect(shouldRenderEvent(event)).toBe(false);
|
||||
});
|
||||
|
||||
it("should return true for other action types", () => {
|
||||
const event = createOtherActionEvent("action-1");
|
||||
|
||||
expect(shouldRenderEvent(event)).toBe(true);
|
||||
});
|
||||
|
||||
it("should return true for PlanningFileEditorObservation", () => {
|
||||
const event = createPlanningObservationEvent("obs-1");
|
||||
|
||||
// Observations should still render (they're handled separately in event-message)
|
||||
expect(shouldRenderEvent(event)).toBe(true);
|
||||
});
|
||||
|
||||
it("should return true for user message events", () => {
|
||||
const event = createUserMessageEvent("msg-1");
|
||||
|
||||
expect(shouldRenderEvent(event)).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -1,159 +0,0 @@
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
|
||||
import { screen, render } from "@testing-library/react";
|
||||
import { EventMessage } from "#/components/v1/chat/event-message";
|
||||
import { useConversationStore } from "#/stores/conversation-store";
|
||||
import {
|
||||
renderWithProviders,
|
||||
createPlanningObservationEvent,
|
||||
} from "test-utils";
|
||||
|
||||
// Mock the feature flag
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
USE_PLANNING_AGENT: vi.fn(() => true),
|
||||
}));
|
||||
|
||||
// Mock useConfig
|
||||
vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: () => ({
|
||||
data: { APP_MODE: "saas" },
|
||||
}),
|
||||
}));
|
||||
|
||||
// Mock PlanPreview component to verify it's rendered
|
||||
vi.mock("#/components/features/chat/plan-preview", () => ({
|
||||
PlanPreview: ({ planContent }: { planContent?: string | null }) => (
|
||||
<div data-testid="plan-preview">Plan Preview: {planContent || "null"}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
describe("EventMessage - PlanPreview rendering", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
// Reset conversation store
|
||||
useConversationStore.setState({
|
||||
planContent: null,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should render PlanPreview when PlanningFileEditorObservation event ID is in planPreviewEventIds", () => {
|
||||
const event = createPlanningObservationEvent("plan-obs-1");
|
||||
const planPreviewEventIds = new Set(["plan-obs-1"]);
|
||||
const planContent = "This is the plan content";
|
||||
|
||||
useConversationStore.setState({ planContent });
|
||||
|
||||
renderWithProviders(
|
||||
<EventMessage
|
||||
event={event}
|
||||
messages={[]}
|
||||
isLastMessage={false}
|
||||
isInLast10Actions={false}
|
||||
planPreviewEventIds={planPreviewEventIds}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId("plan-preview")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(`Plan Preview: ${planContent}`),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should return null when PlanningFileEditorObservation event ID is NOT in planPreviewEventIds", () => {
|
||||
const event = createPlanningObservationEvent("plan-obs-1");
|
||||
const planPreviewEventIds = new Set(["plan-obs-2"]); // Different ID
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<EventMessage
|
||||
event={event}
|
||||
messages={[]}
|
||||
isLastMessage={false}
|
||||
isInLast10Actions={false}
|
||||
planPreviewEventIds={planPreviewEventIds}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("plan-preview")).not.toBeInTheDocument();
|
||||
expect(container.firstChild).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null when planPreviewEventIds is undefined", () => {
|
||||
const event = createPlanningObservationEvent("plan-obs-1");
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<EventMessage
|
||||
event={event}
|
||||
messages={[]}
|
||||
isLastMessage={false}
|
||||
isInLast10Actions={false}
|
||||
planPreviewEventIds={undefined}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("plan-preview")).not.toBeInTheDocument();
|
||||
expect(container.firstChild).toBeNull();
|
||||
});
|
||||
|
||||
it("should use planContent from conversation store", () => {
|
||||
const event = createPlanningObservationEvent("plan-obs-1");
|
||||
const planPreviewEventIds = new Set(["plan-obs-1"]);
|
||||
const planContent = "Store plan content";
|
||||
|
||||
useConversationStore.setState({ planContent });
|
||||
|
||||
renderWithProviders(
|
||||
<EventMessage
|
||||
event={event}
|
||||
messages={[]}
|
||||
isLastMessage={false}
|
||||
isInLast10Actions={false}
|
||||
planPreviewEventIds={planPreviewEventIds}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.getByText(`Plan Preview: ${planContent}`),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should handle null planContent from store", () => {
|
||||
const event = createPlanningObservationEvent("plan-obs-1");
|
||||
const planPreviewEventIds = new Set(["plan-obs-1"]);
|
||||
|
||||
useConversationStore.setState({ planContent: null });
|
||||
|
||||
renderWithProviders(
|
||||
<EventMessage
|
||||
event={event}
|
||||
messages={[]}
|
||||
isLastMessage={false}
|
||||
isInLast10Actions={false}
|
||||
planPreviewEventIds={planPreviewEventIds}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId("plan-preview")).toBeInTheDocument();
|
||||
expect(screen.getByText("Plan Preview: null")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should handle empty planPreviewEventIds set", () => {
|
||||
const event = createPlanningObservationEvent("plan-obs-1");
|
||||
const planPreviewEventIds = new Set<string>();
|
||||
|
||||
const { container } = renderWithProviders(
|
||||
<EventMessage
|
||||
event={event}
|
||||
messages={[]}
|
||||
isLastMessage={false}
|
||||
isInLast10Actions={false}
|
||||
planPreviewEventIds={planPreviewEventIds}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("plan-preview")).not.toBeInTheDocument();
|
||||
expect(container.firstChild).toBeNull();
|
||||
});
|
||||
});
|
||||
@@ -1,195 +0,0 @@
|
||||
import { renderHook } from "@testing-library/react";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
usePlanPreviewEvents,
|
||||
shouldShowPlanPreview,
|
||||
} from "#/components/v1/chat/hooks/use-plan-preview-events";
|
||||
import {
|
||||
OpenHandsEvent,
|
||||
MessageEvent,
|
||||
ObservationEvent,
|
||||
PlanningFileEditorObservation,
|
||||
} from "#/types/v1/core";
|
||||
|
||||
// Helper to create a user message event
|
||||
const createUserMessageEvent = (id: string): MessageEvent => ({
|
||||
id,
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "user",
|
||||
llm_message: {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "User message" }],
|
||||
},
|
||||
activated_microagents: [],
|
||||
extended_content: [],
|
||||
});
|
||||
|
||||
// Helper to create a PlanningFileEditorObservation event
|
||||
const createPlanningObservationEvent = (
|
||||
id: string,
|
||||
actionId: string = "action-1",
|
||||
): ObservationEvent<PlanningFileEditorObservation> => ({
|
||||
id,
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "environment",
|
||||
tool_name: "planning_file_editor",
|
||||
tool_call_id: "call-1",
|
||||
action_id: actionId,
|
||||
observation: {
|
||||
kind: "PlanningFileEditorObservation",
|
||||
content: [{ type: "text", text: "Plan content" }],
|
||||
is_error: false,
|
||||
command: "create",
|
||||
path: "/workspace/PLAN.md",
|
||||
prev_exist: false,
|
||||
old_content: null,
|
||||
new_content: "Plan content",
|
||||
},
|
||||
});
|
||||
|
||||
// Helper to create a non-planning observation event
|
||||
const createOtherObservationEvent = (id: string): ObservationEvent => ({
|
||||
id,
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "environment",
|
||||
tool_name: "execute_bash",
|
||||
tool_call_id: "call-1",
|
||||
action_id: "action-1",
|
||||
observation: {
|
||||
kind: "ExecuteBashObservation",
|
||||
content: [{ type: "text", text: "output" }],
|
||||
command: "echo test",
|
||||
exit_code: 0,
|
||||
error: false,
|
||||
timeout: false,
|
||||
metadata: {
|
||||
exit_code: 0,
|
||||
pid: 12345,
|
||||
username: "user",
|
||||
hostname: "localhost",
|
||||
working_dir: "/home/user",
|
||||
py_interpreter_path: null,
|
||||
prefix: "",
|
||||
suffix: "",
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
describe("usePlanPreviewEvents", () => {
|
||||
it("should return empty set when no events provided", () => {
|
||||
const { result } = renderHook(() => usePlanPreviewEvents([]));
|
||||
|
||||
expect(result.current).toBeInstanceOf(Set);
|
||||
expect(result.current.size).toBe(0);
|
||||
});
|
||||
|
||||
it("should return empty set when no PlanningFileEditorObservation events exist", () => {
|
||||
const events: OpenHandsEvent[] = [
|
||||
createUserMessageEvent("user-1"),
|
||||
createOtherObservationEvent("obs-1"),
|
||||
];
|
||||
|
||||
const { result } = renderHook(() => usePlanPreviewEvents(events));
|
||||
|
||||
expect(result.current.size).toBe(0);
|
||||
});
|
||||
|
||||
it("should return event ID for single PlanningFileEditorObservation in one phase", () => {
|
||||
const events: OpenHandsEvent[] = [
|
||||
createUserMessageEvent("user-1"),
|
||||
createPlanningObservationEvent("plan-obs-1"),
|
||||
];
|
||||
|
||||
const { result } = renderHook(() => usePlanPreviewEvents(events));
|
||||
|
||||
expect(result.current.size).toBe(1);
|
||||
expect(result.current.has("plan-obs-1")).toBe(true);
|
||||
});
|
||||
|
||||
it("should return only the last PlanningFileEditorObservation when multiple exist in one phase", () => {
|
||||
const events: OpenHandsEvent[] = [
|
||||
createUserMessageEvent("user-1"),
|
||||
createPlanningObservationEvent("plan-obs-1"),
|
||||
createPlanningObservationEvent("plan-obs-2"),
|
||||
createPlanningObservationEvent("plan-obs-3"),
|
||||
createOtherObservationEvent("other-obs-1"),
|
||||
];
|
||||
|
||||
const { result } = renderHook(() => usePlanPreviewEvents(events));
|
||||
|
||||
// Should only include the last one in the phase
|
||||
expect(result.current.size).toBe(1);
|
||||
expect(result.current.has("plan-obs-1")).toBe(false);
|
||||
expect(result.current.has("plan-obs-2")).toBe(false);
|
||||
expect(result.current.has("plan-obs-3")).toBe(true);
|
||||
});
|
||||
|
||||
it("should return one event ID per phase when multiple phases exist", () => {
|
||||
const events: OpenHandsEvent[] = [
|
||||
createUserMessageEvent("user-1"),
|
||||
createPlanningObservationEvent("plan-obs-1"),
|
||||
createPlanningObservationEvent("plan-obs-2"),
|
||||
createUserMessageEvent("user-2"),
|
||||
createPlanningObservationEvent("plan-obs-3"),
|
||||
createPlanningObservationEvent("plan-obs-4"),
|
||||
];
|
||||
|
||||
const { result } = renderHook(() => usePlanPreviewEvents(events));
|
||||
|
||||
// Should have one preview per phase (last observation in each phase)
|
||||
expect(result.current.size).toBe(2);
|
||||
expect(result.current.has("plan-obs-2")).toBe(true); // Last in phase 1
|
||||
expect(result.current.has("plan-obs-4")).toBe(true); // Last in phase 2
|
||||
expect(result.current.has("plan-obs-1")).toBe(false);
|
||||
expect(result.current.has("plan-obs-3")).toBe(false);
|
||||
});
|
||||
|
||||
it("should handle phase with no PlanningFileEditorObservation", () => {
|
||||
const events: OpenHandsEvent[] = [
|
||||
createUserMessageEvent("user-1"),
|
||||
createOtherObservationEvent("obs-1"),
|
||||
createUserMessageEvent("user-2"),
|
||||
createPlanningObservationEvent("plan-obs-1"),
|
||||
];
|
||||
|
||||
const { result } = renderHook(() => usePlanPreviewEvents(events));
|
||||
|
||||
// Only phase 2 has a planning observation
|
||||
expect(result.current.size).toBe(1);
|
||||
expect(result.current.has("plan-obs-1")).toBe(true);
|
||||
});
|
||||
|
||||
it("should handle events starting with non-user message", () => {
|
||||
const events: OpenHandsEvent[] = [
|
||||
createOtherObservationEvent("obs-1"),
|
||||
createUserMessageEvent("user-1"),
|
||||
createPlanningObservationEvent("plan-obs-1"),
|
||||
];
|
||||
|
||||
const { result } = renderHook(() => usePlanPreviewEvents(events));
|
||||
|
||||
// Events before first user message should be in first phase
|
||||
expect(result.current.size).toBe(1);
|
||||
expect(result.current.has("plan-obs-1")).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("shouldShowPlanPreview", () => {
|
||||
it("should return true when event ID is in the set", () => {
|
||||
const planPreviewEventIds = new Set(["event-1", "event-2", "event-3"]);
|
||||
|
||||
expect(shouldShowPlanPreview("event-2", planPreviewEventIds)).toBe(true);
|
||||
});
|
||||
|
||||
it("should return false when event ID is not in the set", () => {
|
||||
const planPreviewEventIds = new Set(["event-1", "event-2"]);
|
||||
|
||||
expect(shouldShowPlanPreview("event-3", planPreviewEventIds)).toBe(false);
|
||||
});
|
||||
|
||||
it("should return false when set is empty", () => {
|
||||
const planPreviewEventIds = new Set<string>();
|
||||
|
||||
expect(shouldShowPlanPreview("event-1", planPreviewEventIds)).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -40,18 +40,6 @@ import { conversationWebSocketTestSetup } from "./helpers/msw-websocket-setup";
|
||||
import { useEventStore } from "#/stores/use-event-store";
|
||||
import { isV1Event } from "#/types/v1/type-guards";
|
||||
|
||||
// Mock useUserConversation to return V1 conversation data
|
||||
vi.mock("#/hooks/query/use-user-conversation", () => ({
|
||||
useUserConversation: vi.fn(() => ({
|
||||
data: {
|
||||
conversation_version: "V1",
|
||||
status: "RUNNING",
|
||||
},
|
||||
isLoading: false,
|
||||
error: null,
|
||||
})),
|
||||
}));
|
||||
|
||||
// MSW WebSocket mock setup
|
||||
const { wsLink, server: mswServer } = conversationWebSocketTestSetup();
|
||||
|
||||
@@ -679,16 +667,6 @@ describe("Conversation WebSocket Handler", () => {
|
||||
|
||||
// Set up MSW to mock both the HTTP API and WebSocket connection
|
||||
mswServer.use(
|
||||
// Mock events search for history preloading
|
||||
http.get(
|
||||
`http://localhost:3000/api/v1/conversation/${conversationId}/events/search`,
|
||||
async () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
return HttpResponse.json({
|
||||
items: mockHistoryEvents,
|
||||
});
|
||||
},
|
||||
),
|
||||
http.get(
|
||||
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
|
||||
() => HttpResponse.json(expectedEventCount),
|
||||
@@ -725,6 +703,11 @@ describe("Conversation WebSocket Handler", () => {
|
||||
`http://localhost:3000/api/conversations/${conversationId}`,
|
||||
);
|
||||
|
||||
// Initially should be loading history
|
||||
expect(screen.getByTestId("is-loading-history")).toHaveTextContent(
|
||||
"true",
|
||||
);
|
||||
|
||||
// Wait for all events to be received
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("events-received")).toHaveTextContent("3");
|
||||
@@ -743,14 +726,6 @@ describe("Conversation WebSocket Handler", () => {
|
||||
|
||||
// Set up MSW to mock both the HTTP API and WebSocket connection
|
||||
mswServer.use(
|
||||
// Mock empty events search
|
||||
http.get(
|
||||
`http://localhost:3000/api/v1/conversation/${conversationId}/events/search`,
|
||||
() =>
|
||||
HttpResponse.json({
|
||||
items: [],
|
||||
}),
|
||||
),
|
||||
http.get(
|
||||
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
|
||||
() => HttpResponse.json(0),
|
||||
@@ -800,16 +775,6 @@ describe("Conversation WebSocket Handler", () => {
|
||||
|
||||
// Set up MSW to mock both the HTTP API and WebSocket connection
|
||||
mswServer.use(
|
||||
// Mock events search for history preloading (50 events)
|
||||
http.get(
|
||||
`http://localhost:3000/api/v1/conversation/${conversationId}/events/search`,
|
||||
async () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
return HttpResponse.json({
|
||||
items: mockHistoryEvents,
|
||||
});
|
||||
},
|
||||
),
|
||||
http.get(
|
||||
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
|
||||
() => HttpResponse.json(expectedEventCount),
|
||||
@@ -845,6 +810,11 @@ describe("Conversation WebSocket Handler", () => {
|
||||
`http://localhost:3000/api/conversations/${conversationId}`,
|
||||
);
|
||||
|
||||
// Initially should be loading history
|
||||
expect(screen.getByTestId("is-loading-history")).toHaveTextContent(
|
||||
"true",
|
||||
);
|
||||
|
||||
// Wait for all events to be received
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("events-received")).toHaveTextContent("50");
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
import { describe, it, expect, afterEach, vi } from "vitest";
|
||||
import React from "react";
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
|
||||
import { useConversationHistory } from "#/hooks/query/use-conversation-history";
|
||||
import EventService from "#/api/event-service/event-service.api";
|
||||
import { useUserConversation } from "#/hooks/query/use-user-conversation";
|
||||
import type { Conversation } from "#/api/open-hands.types";
|
||||
import type { OpenHandsEvent } from "#/types/v1/core";
|
||||
|
||||
function makeConversation(version: "V0" | "V1"): Conversation {
|
||||
return {
|
||||
conversation_id: "conv-test",
|
||||
title: "Test Conversation",
|
||||
selected_repository: null,
|
||||
selected_branch: null,
|
||||
git_provider: null,
|
||||
last_updated_at: new Date().toISOString(),
|
||||
created_at: new Date().toISOString(),
|
||||
status: "RUNNING",
|
||||
runtime_status: null,
|
||||
url: null,
|
||||
session_api_key: null,
|
||||
conversation_version: version,
|
||||
};
|
||||
}
|
||||
|
||||
function makeEvent(): OpenHandsEvent {
|
||||
return {
|
||||
id: "evt-1",
|
||||
} as OpenHandsEvent;
|
||||
}
|
||||
|
||||
// --------------------
|
||||
// Mocks
|
||||
// --------------------
|
||||
vi.mock("#/api/open-hands-axios", () => ({
|
||||
openHands: {
|
||||
get: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("#/api/event-service/event-service.api");
|
||||
vi.mock("#/hooks/query/use-user-conversation");
|
||||
|
||||
const queryClient = new QueryClient();
|
||||
|
||||
function wrapper({ children }: { children: React.ReactNode }) {
|
||||
return (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
}
|
||||
|
||||
// --------------------
|
||||
// Tests
|
||||
// --------------------
|
||||
describe("useConversationHistory", () => {
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("calls V1 REST endpoint for V1 conversations", async () => {
|
||||
const v1SearchEventsSpy = vi.spyOn(EventService, "searchEventsV1");
|
||||
|
||||
vi.mocked(useUserConversation).mockReturnValue({
|
||||
data: makeConversation("V1"),
|
||||
isLoading: false,
|
||||
isPending: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
refetch: vi.fn(),
|
||||
} as any);
|
||||
|
||||
v1SearchEventsSpy.mockResolvedValue([makeEvent()]);
|
||||
|
||||
const { result } = renderHook(() => useConversationHistory("conv-123"), {
|
||||
wrapper,
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.data).toBeDefined();
|
||||
});
|
||||
|
||||
expect(EventService.searchEventsV1).toHaveBeenCalledWith("conv-123");
|
||||
expect(EventService.searchEventsV0).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("calls V0 REST endpoint for V0 conversations", async () => {
|
||||
const v0SearchEventsSpy = vi.spyOn(EventService, "searchEventsV0");
|
||||
|
||||
vi.mocked(useUserConversation).mockReturnValue({
|
||||
data: makeConversation("V0"),
|
||||
isLoading: false,
|
||||
isPending: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
refetch: vi.fn(),
|
||||
} as any);
|
||||
|
||||
v0SearchEventsSpy.mockResolvedValue([makeEvent()]);
|
||||
|
||||
const { result } = renderHook(() => useConversationHistory("conv-456"), {
|
||||
wrapper,
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.data).toBeDefined();
|
||||
});
|
||||
|
||||
expect(EventService.searchEventsV0).toHaveBeenCalledWith("conv-456");
|
||||
expect(EventService.searchEventsV1).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -7,19 +7,14 @@ import LoginPage from "#/routes/login";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import AuthService from "#/api/auth-service/auth-service.api";
|
||||
|
||||
const { useEmailVerificationMock, resendEmailVerificationMock } = vi.hoisted(
|
||||
() => ({
|
||||
useEmailVerificationMock: vi.fn(() => ({
|
||||
emailVerified: false,
|
||||
hasDuplicatedEmail: false,
|
||||
emailVerificationModalOpen: false,
|
||||
setEmailVerificationModalOpen: vi.fn(),
|
||||
userId: null as string | null,
|
||||
resendEmailVerification: vi.fn(),
|
||||
})),
|
||||
resendEmailVerificationMock: vi.fn(),
|
||||
}),
|
||||
);
|
||||
const { useEmailVerificationMock } = vi.hoisted(() => ({
|
||||
useEmailVerificationMock: vi.fn(() => ({
|
||||
emailVerified: false,
|
||||
hasDuplicatedEmail: false,
|
||||
emailVerificationModalOpen: false,
|
||||
setEmailVerificationModalOpen: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-github-auth-url", () => ({
|
||||
useGitHubAuthUrl: () => "https://github.com/login/oauth/authorize",
|
||||
@@ -353,8 +348,6 @@ describe("LoginPage", () => {
|
||||
hasDuplicatedEmail: false,
|
||||
emailVerificationModalOpen: false,
|
||||
setEmailVerificationModalOpen: vi.fn(),
|
||||
userId: null,
|
||||
resendEmailVerification: resendEmailVerificationMock,
|
||||
});
|
||||
|
||||
render(<RouterStub initialEntries={["/login"]} />, {
|
||||
@@ -374,8 +367,6 @@ describe("LoginPage", () => {
|
||||
hasDuplicatedEmail: true,
|
||||
emailVerificationModalOpen: false,
|
||||
setEmailVerificationModalOpen: vi.fn(),
|
||||
userId: null,
|
||||
resendEmailVerification: resendEmailVerificationMock,
|
||||
});
|
||||
|
||||
render(<RouterStub initialEntries={["/login"]} />, {
|
||||
@@ -388,41 +379,6 @@ describe("LoginPage", () => {
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should pass userId to EmailVerificationModal when userId is provided", async () => {
|
||||
const user = userEvent.setup();
|
||||
const testUserId = "test-user-id-123";
|
||||
const setEmailVerificationModalOpen = vi.fn();
|
||||
|
||||
useEmailVerificationMock.mockReturnValue({
|
||||
emailVerified: false,
|
||||
hasDuplicatedEmail: false,
|
||||
emailVerificationModalOpen: true,
|
||||
setEmailVerificationModalOpen,
|
||||
userId: testUserId,
|
||||
resendEmailVerification: resendEmailVerificationMock,
|
||||
});
|
||||
|
||||
render(<RouterStub initialEntries={["/login"]} />, {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const resendButton = screen.getByRole("button", {
|
||||
name: /SETTINGS\$RESEND_VERIFICATION/i,
|
||||
});
|
||||
await user.click(resendButton);
|
||||
|
||||
expect(resendEmailVerificationMock).toHaveBeenCalledWith({
|
||||
userId: testUserId,
|
||||
isAuthFlow: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Loading States", () => {
|
||||
@@ -459,15 +415,6 @@ describe("LoginPage", () => {
|
||||
|
||||
describe("Terms and Privacy", () => {
|
||||
it("should display Terms and Privacy notice", async () => {
|
||||
useEmailVerificationMock.mockReturnValue({
|
||||
emailVerified: false,
|
||||
hasDuplicatedEmail: false,
|
||||
emailVerificationModalOpen: false,
|
||||
setEmailVerificationModalOpen: vi.fn(),
|
||||
userId: null as string | null,
|
||||
resendEmailVerification: resendEmailVerificationMock,
|
||||
});
|
||||
|
||||
render(<RouterStub initialEntries={["/login"]} />, {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
@@ -48,7 +48,6 @@ function LoginStub() {
|
||||
searchParams.get("email_verification_required") === "true";
|
||||
const emailVerified = searchParams.get("email_verified") === "true";
|
||||
const emailVerificationText = "AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY";
|
||||
const returnTo = searchParams.get("returnTo");
|
||||
|
||||
return (
|
||||
<div data-testid="login-page">
|
||||
@@ -59,7 +58,6 @@ function LoginStub() {
|
||||
{emailVerificationText}
|
||||
</div>
|
||||
)}
|
||||
{returnTo && <div data-testid="return-to-param">{returnTo}</div>}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
@@ -102,27 +100,6 @@ const RouterStubWithLogin = createRoutesStub([
|
||||
},
|
||||
]);
|
||||
|
||||
const RouterStubWithDeviceVerify = createRoutesStub([
|
||||
{
|
||||
Component: MainApp,
|
||||
path: "/",
|
||||
children: [
|
||||
{
|
||||
Component: () => <div data-testid="outlet-content" />,
|
||||
path: "/",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="device-verify-page" />,
|
||||
path: "/oauth/device/verify",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
Component: LoginStub,
|
||||
path: "/login",
|
||||
},
|
||||
]);
|
||||
|
||||
const renderMainApp = (initialEntries: string[] = ["/"]) =>
|
||||
render(<RouterStub initialEntries={initialEntries} />, {
|
||||
wrapper: ({ children }) => (
|
||||
@@ -334,23 +311,5 @@ describe("MainApp", () => {
|
||||
{ timeout: 2000 },
|
||||
);
|
||||
});
|
||||
|
||||
it("should preserve query parameters in returnTo when redirecting to login", async () => {
|
||||
renderWithLoginStub(RouterStubWithDeviceVerify, [
|
||||
"/oauth/device/verify?user_code=F9XN6BKU",
|
||||
]);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(screen.getByTestId("login-page")).toBeInTheDocument();
|
||||
const returnToElement = screen.getByTestId("return-to-param");
|
||||
expect(returnToElement).toBeInTheDocument();
|
||||
expect(returnToElement.textContent).toBe(
|
||||
"/oauth/device/verify?user_code=F9XN6BKU",
|
||||
);
|
||||
},
|
||||
{ timeout: 2000 },
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -138,72 +138,4 @@ describe("handleEventForUI", () => {
|
||||
anotherActionEvent,
|
||||
]);
|
||||
});
|
||||
|
||||
it("should NOT replace ThinkAction with ThinkObservation", () => {
|
||||
const mockThinkAction: ActionEvent = {
|
||||
id: "test-think-action-1",
|
||||
timestamp: Date.now().toString(),
|
||||
source: "agent",
|
||||
thought: [{ type: "text", text: "I am thinking..." }],
|
||||
thinking_blocks: [],
|
||||
action: {
|
||||
kind: "ThinkAction",
|
||||
thought: "I am thinking...",
|
||||
},
|
||||
tool_name: "think",
|
||||
tool_call_id: "call_think_1",
|
||||
tool_call: {
|
||||
id: "call_think_1",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "think",
|
||||
arguments: "",
|
||||
},
|
||||
},
|
||||
llm_response_id: "response_think",
|
||||
security_risk: SecurityRisk.UNKNOWN,
|
||||
};
|
||||
|
||||
const mockThinkObservation: ObservationEvent = {
|
||||
id: "test-think-observation-1",
|
||||
timestamp: Date.now().toString(),
|
||||
source: "environment",
|
||||
tool_name: "think",
|
||||
tool_call_id: "call_think_1",
|
||||
observation: {
|
||||
kind: "ThinkObservation",
|
||||
content: [{ type: "text", text: "Your thought has been logged." }],
|
||||
},
|
||||
action_id: "test-think-action-1",
|
||||
};
|
||||
|
||||
const initialUiEvents = [mockMessageEvent, mockThinkAction];
|
||||
const result = handleEventForUI(mockThinkObservation, initialUiEvents);
|
||||
|
||||
// ThinkObservation should NOT be added - ThinkAction should remain
|
||||
expect(result).toEqual([mockMessageEvent, mockThinkAction]);
|
||||
expect(result).not.toBe(initialUiEvents);
|
||||
});
|
||||
|
||||
it("should NOT add ThinkObservation even when ThinkAction is not found", () => {
|
||||
const mockThinkObservation: ObservationEvent = {
|
||||
id: "test-think-observation-1",
|
||||
timestamp: Date.now().toString(),
|
||||
source: "environment",
|
||||
tool_name: "think",
|
||||
tool_call_id: "call_think_1",
|
||||
observation: {
|
||||
kind: "ThinkObservation",
|
||||
content: [{ type: "text", text: "Your thought has been logged." }],
|
||||
},
|
||||
action_id: "test-think-action-not-found",
|
||||
};
|
||||
|
||||
const initialUiEvents = [mockMessageEvent];
|
||||
const result = handleEventForUI(mockThinkObservation, initialUiEvents);
|
||||
|
||||
// ThinkObservation should never be added to uiEvents
|
||||
expect(result).toEqual([mockMessageEvent]);
|
||||
expect(result).not.toBe(initialUiEvents);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -103,7 +103,7 @@ export interface V1AppConversation {
|
||||
|
||||
export interface Skill {
|
||||
name: string;
|
||||
type: "repo" | "knowledge" | "agentskills";
|
||||
type: "repo" | "knowledge";
|
||||
content: string;
|
||||
triggers: string[];
|
||||
}
|
||||
|
||||
@@ -5,8 +5,6 @@ import type {
|
||||
ConfirmationResponseRequest,
|
||||
ConfirmationResponseResponse,
|
||||
} from "./event-service.types";
|
||||
import { openHands } from "../open-hands-axios";
|
||||
import { OpenHandsEvent } from "#/types/v1/core";
|
||||
|
||||
class EventService {
|
||||
/**
|
||||
@@ -63,27 +61,5 @@ class EventService {
|
||||
);
|
||||
return data;
|
||||
}
|
||||
|
||||
// V1 conversations — App Server REST endpoint
|
||||
static async searchEventsV1(conversationId: string, limit = 100) {
|
||||
const { data } = await openHands.get<{
|
||||
items: OpenHandsEvent[];
|
||||
}>(`/api/v1/conversation/${conversationId}/events/search`, {
|
||||
params: { limit },
|
||||
});
|
||||
|
||||
return data.items;
|
||||
}
|
||||
|
||||
// V0 conversations — Legacy REST endpoint
|
||||
static async searchEventsV0(conversationId: string, limit = 100) {
|
||||
const { data } = await openHands.get<{
|
||||
events: OpenHandsEvent[];
|
||||
}>(`/api/conversations/${conversationId}/events`, {
|
||||
params: { limit },
|
||||
});
|
||||
|
||||
return data.events;
|
||||
}
|
||||
}
|
||||
export default EventService;
|
||||
|
||||
@@ -110,7 +110,7 @@ export interface InputMetadata {
|
||||
|
||||
export interface Microagent {
|
||||
name: string;
|
||||
type: "repo" | "knowledge" | "agentskills";
|
||||
type: "repo" | "knowledge";
|
||||
content: string;
|
||||
triggers: string[];
|
||||
}
|
||||
|
||||
@@ -1,24 +1,22 @@
|
||||
import { useMemo } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { ArrowUpRight } from "lucide-react";
|
||||
import LessonPlanIcon from "#/icons/lesson-plan.svg?react";
|
||||
import { USE_PLANNING_AGENT } from "#/utils/feature-flags";
|
||||
import { Typography } from "#/ui/typography";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { MarkdownRenderer } from "#/components/features/markdown/markdown-renderer";
|
||||
|
||||
const MAX_CONTENT_LENGTH = 300;
|
||||
|
||||
interface PlanPreviewProps {
|
||||
/** Raw plan content from PLAN.md file */
|
||||
planContent?: string | null;
|
||||
title?: string;
|
||||
description?: string;
|
||||
onViewClick?: () => void;
|
||||
onBuildClick?: () => void;
|
||||
}
|
||||
|
||||
// TODO: Remove the hardcoded values and use the plan content from the conversation store
|
||||
/* eslint-disable i18next/no-literal-string */
|
||||
export function PlanPreview({
|
||||
planContent,
|
||||
title = "Improve Developer Onboarding and Examples",
|
||||
description = "Based on the analysis of Browser-Use's current documentation and examples, this plan addresses gaps in developer onboarding by creating a progressive learning path, troubleshooting resources, and practical examples that address real-world scenarios (like the LM Studio/local LLM integration issues encountered...",
|
||||
onViewClick,
|
||||
onBuildClick,
|
||||
}: PlanPreviewProps) {
|
||||
@@ -26,13 +24,6 @@ export function PlanPreview({
|
||||
|
||||
const shouldUsePlanningAgent = USE_PLANNING_AGENT();
|
||||
|
||||
// Truncate plan content for preview
|
||||
const truncatedContent = useMemo(() => {
|
||||
if (!planContent) return "";
|
||||
if (planContent.length <= MAX_CONTENT_LENGTH) return planContent;
|
||||
return `${planContent.slice(0, MAX_CONTENT_LENGTH)}...`;
|
||||
}, [planContent]);
|
||||
|
||||
if (!shouldUsePlanningAgent) {
|
||||
return null;
|
||||
}
|
||||
@@ -50,7 +41,6 @@ export function PlanPreview({
|
||||
type="button"
|
||||
onClick={onViewClick}
|
||||
className="flex items-center gap-1 hover:opacity-80 transition-opacity"
|
||||
data-testid="plan-preview-view-button"
|
||||
>
|
||||
<Typography.Text className="font-medium text-[11px] text-white tracking-[0.11px] leading-4">
|
||||
{t(I18nKey.COMMON$VIEW)}
|
||||
@@ -60,27 +50,16 @@ export function PlanPreview({
|
||||
</div>
|
||||
|
||||
{/* Content */}
|
||||
<div
|
||||
data-testid="plan-preview-content"
|
||||
className="flex flex-col gap-[10px] p-4 text-[15px] text-white leading-[29px]"
|
||||
>
|
||||
{truncatedContent && (
|
||||
<>
|
||||
<MarkdownRenderer includeStandard includeHeadings>
|
||||
{truncatedContent}
|
||||
</MarkdownRenderer>
|
||||
{planContent && planContent.length > MAX_CONTENT_LENGTH && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={onViewClick}
|
||||
className="text-[#4a67bd] cursor-pointer hover:underline text-left"
|
||||
data-testid="plan-preview-read-more-button"
|
||||
>
|
||||
{t(I18nKey.COMMON$READ_MORE)}
|
||||
</button>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
<div className="flex flex-col gap-[10px] p-4">
|
||||
<h3 className="font-bold text-[19px] text-white leading-[29px]">
|
||||
{title}
|
||||
</h3>
|
||||
<p className="text-[15px] text-white leading-[29px]">
|
||||
{description}
|
||||
<Typography.Text className="text-[#4a67bd] cursor-pointer hover:underline ml-1">
|
||||
{t(I18nKey.COMMON$READ_MORE)}
|
||||
</Typography.Text>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Footer */}
|
||||
@@ -89,7 +68,6 @@ export function PlanPreview({
|
||||
type="button"
|
||||
onClick={onBuildClick}
|
||||
className="bg-white flex items-center justify-center h-[26px] px-2 rounded-[4px] w-[93px] hover:opacity-90 transition-opacity cursor-pointer"
|
||||
data-testid="plan-preview-build-button"
|
||||
>
|
||||
<Typography.Text className="font-medium text-[14px] text-black leading-5">
|
||||
{t(I18nKey.COMMON$BUILD)}{" "}
|
||||
|
||||
@@ -11,15 +11,6 @@ interface SkillItemProps {
|
||||
}
|
||||
|
||||
export function SkillItem({ skill, isExpanded, onToggle }: SkillItemProps) {
|
||||
let skillTypeLabel: string;
|
||||
if (skill.type === "repo") {
|
||||
skillTypeLabel = "Repository";
|
||||
} else if (skill.type === "knowledge") {
|
||||
skillTypeLabel = "Knowledge";
|
||||
} else {
|
||||
skillTypeLabel = "AgentSkills";
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-md overflow-hidden">
|
||||
<button
|
||||
@@ -34,7 +25,7 @@ export function SkillItem({ skill, isExpanded, onToggle }: SkillItemProps) {
|
||||
</div>
|
||||
<div className="flex items-center">
|
||||
<Typography.Text className="px-2 py-1 text-xs rounded-full bg-gray-800 mr-2">
|
||||
{skillTypeLabel}
|
||||
{skill.type === "repo" ? "Repository" : "Knowledge"}
|
||||
</Typography.Text>
|
||||
<Typography.Text className="text-gray-300">
|
||||
{isExpanded ? (
|
||||
|
||||
@@ -27,11 +27,6 @@ export const shouldRenderEvent = (event: OpenHandsEvent) => {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Hide PlanningFileEditorAction - handled separately with PlanPreview component
|
||||
if (actionType === "PlanningFileEditorAction") {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,9 @@ import {
|
||||
isObservationEvent,
|
||||
isAgentErrorEvent,
|
||||
isUserMessageEvent,
|
||||
isPlanningFileEditorObservationEvent,
|
||||
} from "#/types/v1/type-guards";
|
||||
import { MicroagentStatus } from "#/types/microagent-status";
|
||||
import { useConfig } from "#/hooks/query/use-config";
|
||||
import { useConversationStore } from "#/stores/conversation-store";
|
||||
// TODO: Implement V1 feedback functionality when API supports V1 event IDs
|
||||
// import { useFeedbackExists } from "#/hooks/query/use-feedback-exists";
|
||||
import {
|
||||
@@ -21,8 +19,6 @@ import {
|
||||
ThoughtEventMessage,
|
||||
} from "./event-message-components";
|
||||
import { createSkillReadyEvent } from "./event-content-helpers/create-skill-ready-event";
|
||||
import { PlanPreview } from "../../features/chat/plan-preview";
|
||||
import { shouldShowPlanPreview } from "./hooks/use-plan-preview-events";
|
||||
|
||||
interface EventMessageProps {
|
||||
event: OpenHandsEvent & { isFromPlanningAgent?: boolean };
|
||||
@@ -37,8 +33,6 @@ interface EventMessageProps {
|
||||
tooltip?: string;
|
||||
}>;
|
||||
isInLast10Actions: boolean;
|
||||
/** Set of event IDs that should render PlanPreview (one per user message phase) */
|
||||
planPreviewEventIds?: Set<string>;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -149,10 +143,8 @@ export function EventMessage({
|
||||
microagentPRUrl,
|
||||
actions,
|
||||
isInLast10Actions,
|
||||
planPreviewEventIds,
|
||||
}: EventMessageProps) {
|
||||
const { data: config } = useConfig();
|
||||
const { planContent } = useConversationStore();
|
||||
|
||||
// V1 events use string IDs, but useFeedbackExists expects number
|
||||
// For now, we'll skip feedback functionality for V1 events
|
||||
@@ -206,21 +198,6 @@ export function EventMessage({
|
||||
|
||||
// Observation events - find the corresponding action and render thought + observation
|
||||
if (isObservationEvent(event)) {
|
||||
// Handle PlanningFileEditorObservation specially
|
||||
if (isPlanningFileEditorObservationEvent(event)) {
|
||||
// Only show PlanPreview if this event is marked as the one to display
|
||||
// (last PlanningFileEditorObservation in its phase)
|
||||
if (
|
||||
planPreviewEventIds &&
|
||||
shouldShowPlanPreview(event.id, planPreviewEventIds)
|
||||
) {
|
||||
return <PlanPreview planContent={planContent} />;
|
||||
}
|
||||
// Not the designated preview event for this phase - render nothing
|
||||
// This prevents duplicate previews within the same phase
|
||||
return null;
|
||||
}
|
||||
|
||||
// Find the action that this observation is responding to
|
||||
const correspondingAction = messages.find(
|
||||
(msg) => isActionEvent(msg) && msg.id === event.action_id,
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
import { useMemo } from "react";
|
||||
import { OpenHandsEvent } from "#/types/v1/core";
|
||||
import {
|
||||
isUserMessageEvent,
|
||||
isPlanningFileEditorObservationEvent,
|
||||
} from "#/types/v1/type-guards";
|
||||
|
||||
/**
|
||||
* Groups events into phases based on user messages.
|
||||
* A phase starts with a user message and includes all subsequent events
|
||||
* until the next user message.
|
||||
*
|
||||
* @param events - The full list of events
|
||||
* @returns Array of phases, where each phase is an array of events
|
||||
*/
|
||||
function groupEventsByPhase(events: OpenHandsEvent[]): OpenHandsEvent[][] {
|
||||
const phases: OpenHandsEvent[][] = [];
|
||||
let currentPhase: OpenHandsEvent[] = [];
|
||||
|
||||
for (const event of events) {
|
||||
if (isUserMessageEvent(event)) {
|
||||
// Start a new phase with the user message
|
||||
if (currentPhase.length > 0) {
|
||||
phases.push(currentPhase);
|
||||
}
|
||||
currentPhase = [event];
|
||||
} else {
|
||||
// Add event to current phase
|
||||
currentPhase.push(event);
|
||||
}
|
||||
}
|
||||
|
||||
// Don't forget the last phase
|
||||
if (currentPhase.length > 0) {
|
||||
phases.push(currentPhase);
|
||||
}
|
||||
|
||||
return phases;
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the last PlanningFileEditorObservation in a phase.
|
||||
*
|
||||
* @param phase - Array of events in a phase
|
||||
* @returns The event ID of the last PlanningFileEditorObservation, or null
|
||||
*/
|
||||
function findLastPlanningObservationInPhase(
|
||||
phase: OpenHandsEvent[],
|
||||
): string | null {
|
||||
// Iterate backwards to find the last one
|
||||
for (let i = phase.length - 1; i >= 0; i -= 1) {
|
||||
const event = phase[i];
|
||||
if (isPlanningFileEditorObservationEvent(event)) {
|
||||
return event.id;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export interface PlanPreviewEventInfo {
|
||||
eventId: string;
|
||||
/** Index of this plan preview in the conversation (1st, 2nd, etc.) */
|
||||
phaseIndex: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to determine which PlanningFileEditorObservation events should render PlanPreview.
|
||||
*
|
||||
* This hook implements phase-based grouping where:
|
||||
* - A phase starts with a user message and ends at the next user message
|
||||
* - Only the LAST PlanningFileEditorObservation in each phase shows PlanPreview
|
||||
* - This ensures only one preview per user request, even with multiple observations
|
||||
*
|
||||
* Scenario handling:
|
||||
* - Scenario 1 (Create plan): Multiple observations in one phase → 1 preview
|
||||
* - Scenario 2 (Create then update): Two user messages → two phases → 2 previews
|
||||
* - Scenario 3 (Create + update while processing): Two user messages → 2 previews
|
||||
*
|
||||
* @param allEvents - Full list of v1 events (for phase detection)
|
||||
* @returns Set of event IDs that should render PlanPreview
|
||||
*/
|
||||
export function usePlanPreviewEvents(allEvents: OpenHandsEvent[]): Set<string> {
|
||||
return useMemo(() => {
|
||||
const planPreviewEventIds = new Set<string>();
|
||||
|
||||
// Group events by phases (user message boundaries)
|
||||
const phases = groupEventsByPhase(allEvents);
|
||||
|
||||
// For each phase, find the last PlanningFileEditorObservation
|
||||
phases.forEach((phase) => {
|
||||
const lastPlanningObservationId =
|
||||
findLastPlanningObservationInPhase(phase);
|
||||
if (lastPlanningObservationId) {
|
||||
planPreviewEventIds.add(lastPlanningObservationId);
|
||||
}
|
||||
});
|
||||
|
||||
return planPreviewEventIds;
|
||||
}, [allEvents]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a specific event should render PlanPreview.
|
||||
*
|
||||
* @param eventId - The event ID to check
|
||||
* @param planPreviewEventIds - Set of event IDs that should render PlanPreview
|
||||
* @returns true if this event should render PlanPreview
|
||||
*/
|
||||
export function shouldShowPlanPreview(
|
||||
eventId: string,
|
||||
planPreviewEventIds: Set<string>,
|
||||
): boolean {
|
||||
return planPreviewEventIds.has(eventId);
|
||||
}
|
||||
@@ -3,7 +3,6 @@ import { OpenHandsEvent } from "#/types/v1/core";
|
||||
import { EventMessage } from "./event-message";
|
||||
import { ChatMessage } from "../../features/chat/chat-message";
|
||||
import { useOptimisticUserMessageStore } from "#/stores/optimistic-user-message-store";
|
||||
import { usePlanPreviewEvents } from "./hooks/use-plan-preview-events";
|
||||
// TODO: Implement microagent functionality for V1 when APIs support V1 event IDs
|
||||
// import { AgentState } from "#/types/agent-state";
|
||||
// import MemoryIcon from "#/icons/memory_icon.svg?react";
|
||||
@@ -19,10 +18,6 @@ export const Messages: React.FC<MessagesProps> = React.memo(
|
||||
|
||||
const optimisticUserMessage = getOptimisticUserMessage();
|
||||
|
||||
// Get the set of event IDs that should render PlanPreview
|
||||
// This ensures only one preview per user message "phase"
|
||||
const planPreviewEventIds = usePlanPreviewEvents(allEvents);
|
||||
|
||||
// TODO: Implement microagent functionality for V1 if needed
|
||||
// For now, we'll skip microagent features
|
||||
|
||||
@@ -35,7 +30,6 @@ export const Messages: React.FC<MessagesProps> = React.memo(
|
||||
messages={allEvents}
|
||||
isLastMessage={messages.length - 1 === index}
|
||||
isInLast10Actions={messages.length - 1 - index < 10}
|
||||
planPreviewEventIds={planPreviewEventIds}
|
||||
// Microagent props - not implemented yet for V1
|
||||
// microagentStatus={undefined}
|
||||
// microagentConversationId={undefined}
|
||||
|
||||
@@ -46,7 +46,6 @@ import { useTracking } from "#/hooks/use-tracking";
|
||||
import { useReadConversationFile } from "#/hooks/mutation/use-read-conversation-file";
|
||||
import useMetricsStore from "#/stores/metrics-store";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { useConversationHistory } from "#/hooks/query/use-conversation-history";
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||
export type V1_WebSocketConnectionState =
|
||||
@@ -307,21 +306,6 @@ export function ConversationWebSocketProvider({
|
||||
latestPlanningFileEventRef.current = null;
|
||||
}, [conversationId]);
|
||||
|
||||
const { data: preloadedEvents } = useConversationHistory(conversationId);
|
||||
|
||||
useEffect(() => {
|
||||
if (!preloadedEvents || preloadedEvents.length === 0) {
|
||||
setIsLoadingHistoryMain(false);
|
||||
return;
|
||||
}
|
||||
|
||||
for (const event of preloadedEvents) {
|
||||
addEvent(event);
|
||||
}
|
||||
|
||||
setIsLoadingHistoryMain(false);
|
||||
}, [preloadedEvents, addEvent]);
|
||||
|
||||
// Separate message handlers for each connection
|
||||
const handleMainMessage = useCallback(
|
||||
(messageEvent: MessageEvent) => {
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import EventService from "#/api/event-service/event-service.api";
|
||||
import { useUserConversation } from "#/hooks/query/use-user-conversation";
|
||||
|
||||
export const useConversationHistory = (conversationId?: string) => {
|
||||
const { data: conversation } = useUserConversation(conversationId ?? null);
|
||||
|
||||
return useQuery({
|
||||
queryKey: ["conversation-history", conversationId, conversation],
|
||||
enabled: !!conversationId && !!conversation,
|
||||
queryFn: async () => {
|
||||
if (!conversationId || !conversation) return [];
|
||||
|
||||
if (conversation.conversation_version === "V1") {
|
||||
return EventService.searchEventsV1(conversationId);
|
||||
}
|
||||
|
||||
return EventService.searchEventsV0(conversationId);
|
||||
},
|
||||
staleTime: 30_000,
|
||||
});
|
||||
};
|
||||
@@ -20,7 +20,6 @@ export default function LoginPage() {
|
||||
recaptchaBlocked,
|
||||
emailVerificationModalOpen,
|
||||
setEmailVerificationModalOpen,
|
||||
userId,
|
||||
} = useEmailVerification();
|
||||
|
||||
const gitHubAuthUrl = useGitHubAuthUrl({
|
||||
@@ -78,7 +77,6 @@ export default function LoginPage() {
|
||||
onClose={() => {
|
||||
setEmailVerificationModalOpen(false);
|
||||
}}
|
||||
userId={userId}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
|
||||
@@ -5,7 +5,6 @@ import {
|
||||
Outlet,
|
||||
useNavigate,
|
||||
useLocation,
|
||||
useSearchParams,
|
||||
} from "react-router";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
@@ -68,7 +67,6 @@ export default function MainApp() {
|
||||
const appTitle = useAppTitle();
|
||||
const navigate = useNavigate();
|
||||
const { pathname } = useLocation();
|
||||
const [searchParams] = useSearchParams();
|
||||
const isOnTosPage = useIsOnTosPage();
|
||||
const { data: settings } = useSettings();
|
||||
const { migrateUserConsent } = useMigrateUserConsent();
|
||||
@@ -184,18 +182,13 @@ export default function MainApp() {
|
||||
|
||||
React.useEffect(() => {
|
||||
if (shouldRedirectToLogin) {
|
||||
// Include search params in returnTo to preserve query string (e.g., user_code for device OAuth)
|
||||
const searchString = searchParams.toString();
|
||||
let fullPath = "";
|
||||
if (pathname !== "/") {
|
||||
fullPath = searchString ? `${pathname}?${searchString}` : pathname;
|
||||
}
|
||||
const loginUrl = fullPath
|
||||
? `/login?returnTo=${encodeURIComponent(fullPath)}`
|
||||
const returnTo = pathname !== "/" ? pathname : "";
|
||||
const loginUrl = returnTo
|
||||
? `/login?returnTo=${encodeURIComponent(returnTo)}`
|
||||
: "/login";
|
||||
navigate(loginUrl, { replace: true });
|
||||
}
|
||||
}, [shouldRedirectToLogin, pathname, searchParams, navigate]);
|
||||
}, [shouldRedirectToLogin, pathname, navigate]);
|
||||
|
||||
if (shouldRedirectToLogin) {
|
||||
return (
|
||||
|
||||
@@ -213,37 +213,6 @@ export interface BrowserCloseTabAction extends ActionBase<"BrowserCloseTabAction
|
||||
tab_id: string;
|
||||
}
|
||||
|
||||
export interface PlanningFileEditorAction extends ActionBase<"PlanningFileEditorAction"> {
|
||||
/**
|
||||
* The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.
|
||||
*/
|
||||
command: "view" | "create" | "str_replace" | "insert" | "undo_edit";
|
||||
/**
|
||||
* Absolute path to file (typically /workspace/project/PLAN.md).
|
||||
*/
|
||||
path: string;
|
||||
/**
|
||||
* Required parameter of `create` command, with the content of the file to be created.
|
||||
*/
|
||||
file_text: string | null;
|
||||
/**
|
||||
* Required parameter of `str_replace` command containing the string in `path` to replace.
|
||||
*/
|
||||
old_str: string | null;
|
||||
/**
|
||||
* Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.
|
||||
*/
|
||||
new_str: string | null;
|
||||
/**
|
||||
* Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`. Must be >= 1.
|
||||
*/
|
||||
insert_line: number | null;
|
||||
/**
|
||||
* Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown.
|
||||
*/
|
||||
view_range: [number, number] | null;
|
||||
}
|
||||
|
||||
export type Action =
|
||||
| MCPToolAction
|
||||
| FinishAction
|
||||
@@ -253,7 +222,6 @@ export type Action =
|
||||
| FileEditorAction
|
||||
| StrReplaceEditorAction
|
||||
| TaskTrackerAction
|
||||
| PlanningFileEditorAction
|
||||
| BrowserNavigateAction
|
||||
| BrowserClickAction
|
||||
| BrowserTypeAction
|
||||
|
||||
@@ -4,7 +4,6 @@ import { isObservationEvent } from "#/types/v1/type-guards";
|
||||
/**
|
||||
* Handles adding an event to the UI events array
|
||||
* Replaces actions with observations when they arrive (so UI shows observation instead of action)
|
||||
* Exception: ThinkAction is NOT replaced because the thought content is in the action, not in the observation
|
||||
*/
|
||||
export const handleEventForUI = (
|
||||
event: OpenHandsEvent,
|
||||
@@ -13,17 +12,12 @@ export const handleEventForUI = (
|
||||
const newUiEvents = [...uiEvents];
|
||||
|
||||
if (isObservationEvent(event)) {
|
||||
// Don't add ThinkObservation at all - we keep the ThinkAction instead
|
||||
// The thought content is in the action, not the observation
|
||||
if (event.observation.kind === "ThinkObservation") {
|
||||
return newUiEvents;
|
||||
}
|
||||
|
||||
// Find and replace the corresponding action from uiEvents
|
||||
const actionIndex = newUiEvents.findIndex(
|
||||
(uiEvent) => uiEvent.id === event.action_id,
|
||||
);
|
||||
if (actionIndex !== -1) {
|
||||
// Replace the action with the observation
|
||||
newUiEvents[actionIndex] = event;
|
||||
} else {
|
||||
// Action not found in uiEvents, just add the observation
|
||||
|
||||
@@ -7,13 +7,6 @@ import { I18nextProvider, initReactI18next } from "react-i18next";
|
||||
import i18n from "i18next";
|
||||
import { vi } from "vitest";
|
||||
import { AxiosError } from "axios";
|
||||
import {
|
||||
ActionEvent,
|
||||
MessageEvent,
|
||||
ObservationEvent,
|
||||
PlanningFileEditorObservation,
|
||||
} from "#/types/v1/core";
|
||||
import { SecurityRisk } from "#/types/v1/core";
|
||||
|
||||
export const useParamsMock = vi.fn(() => ({
|
||||
conversationId: "test-conversation-id",
|
||||
@@ -105,100 +98,3 @@ export const createAxiosError = (
|
||||
config: {},
|
||||
},
|
||||
);
|
||||
|
||||
// Helper to create a PlanningFileEditorAction event
|
||||
export const createPlanningFileEditorActionEvent = (
|
||||
id: string,
|
||||
): ActionEvent => ({
|
||||
id,
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "agent",
|
||||
thought: [{ type: "text", text: "Planning action" }],
|
||||
thinking_blocks: [],
|
||||
action: {
|
||||
kind: "PlanningFileEditorAction",
|
||||
command: "create",
|
||||
path: "/workspace/PLAN.md",
|
||||
file_text: "Plan content",
|
||||
old_str: null,
|
||||
new_str: null,
|
||||
insert_line: null,
|
||||
view_range: null,
|
||||
},
|
||||
tool_name: "planning_file_editor",
|
||||
tool_call_id: "call-1",
|
||||
tool_call: {
|
||||
id: "call-1",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "planning_file_editor",
|
||||
arguments: '{"command": "create"}',
|
||||
},
|
||||
},
|
||||
llm_response_id: "response-1",
|
||||
security_risk: SecurityRisk.UNKNOWN,
|
||||
});
|
||||
|
||||
// Helper to create a non-planning action event
|
||||
export const createOtherActionEvent = (id: string): ActionEvent => ({
|
||||
id,
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "agent",
|
||||
thought: [{ type: "text", text: "Other action" }],
|
||||
thinking_blocks: [],
|
||||
action: {
|
||||
kind: "ExecuteBashAction",
|
||||
command: "echo test",
|
||||
is_input: false,
|
||||
timeout: null,
|
||||
reset: false,
|
||||
},
|
||||
tool_name: "execute_bash",
|
||||
tool_call_id: "call-1",
|
||||
tool_call: {
|
||||
id: "call-1",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "execute_bash",
|
||||
arguments: '{"command": "echo test"}',
|
||||
},
|
||||
},
|
||||
llm_response_id: "response-1",
|
||||
security_risk: SecurityRisk.UNKNOWN,
|
||||
});
|
||||
|
||||
// Helper to create a PlanningFileEditorObservation event
|
||||
export const createPlanningObservationEvent = (
|
||||
id: string,
|
||||
actionId: string = "action-1",
|
||||
): ObservationEvent<PlanningFileEditorObservation> => ({
|
||||
id,
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "environment",
|
||||
tool_name: "planning_file_editor",
|
||||
tool_call_id: "call-1",
|
||||
action_id: actionId,
|
||||
observation: {
|
||||
kind: "PlanningFileEditorObservation",
|
||||
content: [{ type: "text", text: "Plan content" }],
|
||||
is_error: false,
|
||||
command: "create",
|
||||
path: "/workspace/PLAN.md",
|
||||
prev_exist: false,
|
||||
old_content: null,
|
||||
new_content: "Plan content",
|
||||
},
|
||||
});
|
||||
|
||||
// Helper to create a user message event
|
||||
export const createUserMessageEvent = (id: string): MessageEvent => ({
|
||||
id,
|
||||
timestamp: new Date().toISOString(),
|
||||
source: "user",
|
||||
llm_message: {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "User message" }],
|
||||
},
|
||||
activated_microagents: [],
|
||||
extended_content: [],
|
||||
});
|
||||
|
||||
@@ -1,8 +1,18 @@
|
||||
# OpenHands Architecture
|
||||
# OpenHands
|
||||
|
||||
This directory contains the core components of OpenHands.
|
||||
|
||||
For an overview of the system architecture, see the [architecture documentation](https://docs.openhands.dev/usage/architecture/backend) (v0 backend architecture).
|
||||
## Documentation
|
||||
|
||||
- **[Architecture Documentation](./architecture/README.md)** - Detailed system architecture with Mermaid diagrams covering:
|
||||
- System Architecture Overview
|
||||
- Conversation Startup & WebSocket Flow
|
||||
- Authentication Flow (Keycloak)
|
||||
- Agent Execution & LLM Flow
|
||||
- External Integrations (GitHub, Slack, Jira, etc.)
|
||||
- Metrics, Logs & Observability
|
||||
|
||||
- **[External Architecture Docs](https://docs.openhands.dev/usage/architecture/backend)** - Official documentation (v0 backend architecture)
|
||||
|
||||
## Classes
|
||||
|
||||
|
||||
@@ -176,6 +176,6 @@ class SkillResponse(BaseModel):
|
||||
"""Response model for skills endpoint."""
|
||||
|
||||
name: str
|
||||
type: Literal['repo', 'knowledge', 'agentskills']
|
||||
type: Literal['repo', 'knowledge']
|
||||
content: str
|
||||
triggers: list[str] = []
|
||||
|
||||
@@ -503,6 +503,13 @@ async def get_conversation_skills(
|
||||
|
||||
agent_server_url = replace_localhost_hostname_for_docker(agent_server_url)
|
||||
|
||||
# Create remote workspace
|
||||
remote_workspace = AsyncRemoteWorkspace(
|
||||
host=agent_server_url,
|
||||
api_key=sandbox.session_api_key,
|
||||
working_dir=sandbox_spec.working_dir,
|
||||
)
|
||||
|
||||
# Load skills from all sources
|
||||
logger.info(f'Loading skills for conversation {conversation_id}')
|
||||
|
||||
@@ -511,9 +518,9 @@ async def get_conversation_skills(
|
||||
if isinstance(app_conversation_service, AppConversationServiceBase):
|
||||
all_skills = await app_conversation_service.load_and_merge_all_skills(
|
||||
sandbox,
|
||||
remote_workspace,
|
||||
conversation.selected_repository,
|
||||
sandbox_spec.working_dir,
|
||||
agent_server_url,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -524,11 +531,9 @@ async def get_conversation_skills(
|
||||
# Transform skills to response format
|
||||
skills_response = []
|
||||
for skill in all_skills:
|
||||
# Determine type based on AgentSkills format and trigger
|
||||
skill_type: Literal['repo', 'knowledge', 'agentskills']
|
||||
if skill.is_agentskills_format:
|
||||
skill_type = 'agentskills'
|
||||
elif skill.trigger is None:
|
||||
# Determine type based on trigger
|
||||
skill_type: Literal['repo', 'knowledge']
|
||||
if skill.trigger is None:
|
||||
skill_type = 'repo'
|
||||
else:
|
||||
skill_type = 'knowledge'
|
||||
|
||||
@@ -95,7 +95,6 @@ class AppConversationService(ABC):
|
||||
task: AppConversationStartTask,
|
||||
sandbox: SandboxInfo,
|
||||
workspace: AsyncRemoteWorkspace,
|
||||
agent_server_url: str,
|
||||
) -> AsyncGenerator[AppConversationStartTask, None]:
|
||||
"""Run the setup scripts for the project and yield status updates"""
|
||||
yield task
|
||||
|
||||
@@ -21,16 +21,18 @@ from openhands.app_server.app_conversation.app_conversation_service import (
|
||||
AppConversationService,
|
||||
)
|
||||
from openhands.app_server.app_conversation.skill_loader import (
|
||||
build_org_config,
|
||||
build_sandbox_config,
|
||||
load_skills_from_agent_server,
|
||||
load_global_skills,
|
||||
load_org_skills,
|
||||
load_repo_skills,
|
||||
load_sandbox_skills,
|
||||
merge_skills,
|
||||
)
|
||||
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.sdk import Agent
|
||||
from openhands.sdk.context.agent_context import AgentContext
|
||||
from openhands.sdk.context.condenser import LLMSummarizingCondenser
|
||||
from openhands.sdk.context.skills import Skill
|
||||
from openhands.sdk.context.skills import load_user_skills
|
||||
from openhands.sdk.llm import LLM
|
||||
from openhands.sdk.security.analyzer import SecurityAnalyzerBase
|
||||
from openhands.sdk.security.confirmation_policy import (
|
||||
@@ -59,74 +61,67 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
async def load_and_merge_all_skills(
|
||||
self,
|
||||
sandbox: SandboxInfo,
|
||||
remote_workspace: AsyncRemoteWorkspace,
|
||||
selected_repository: str | None,
|
||||
working_dir: str,
|
||||
agent_server_url: str,
|
||||
) -> list[Skill]:
|
||||
"""Load skills from all sources via the agent-server.
|
||||
) -> list:
|
||||
"""Load skills from all sources and merge them.
|
||||
|
||||
This method calls the agent-server's /api/skills endpoint to load and
|
||||
merge skills from all sources. The agent-server handles:
|
||||
- Public skills (from OpenHands/skills GitHub repo)
|
||||
- User skills (from ~/.openhands/skills/)
|
||||
- Organization skills (from {org}/.openhands repo)
|
||||
- Project/repo skills (from workspace .openhands/skills/)
|
||||
- Sandbox skills (from exposed URLs)
|
||||
This method handles all errors gracefully and will return an empty list
|
||||
if skill loading fails completely.
|
||||
|
||||
Args:
|
||||
sandbox: SandboxInfo containing exposed URLs and agent-server URL
|
||||
remote_workspace: AsyncRemoteWorkspace for loading repo skills
|
||||
selected_repository: Repository name or None
|
||||
working_dir: Working directory path
|
||||
agent_server_url: Agent-server URL (required)
|
||||
|
||||
Returns:
|
||||
List of merged Skill objects from all sources, or empty list on failure
|
||||
"""
|
||||
try:
|
||||
_logger.debug('Loading skills for V1 conversation via agent-server')
|
||||
_logger.debug('Loading skills for V1 conversation')
|
||||
|
||||
if not agent_server_url:
|
||||
_logger.warning('No agent-server URL available, cannot load skills')
|
||||
return []
|
||||
# Load skills from all sources
|
||||
sandbox_skills = load_sandbox_skills(sandbox)
|
||||
global_skills = load_global_skills()
|
||||
# Load user skills from ~/.openhands/skills/ directory
|
||||
# Uses the SDK's load_user_skills() function which handles loading from
|
||||
# ~/.openhands/skills/ and ~/.openhands/microagents/ (for backward compatibility)
|
||||
try:
|
||||
user_skills = load_user_skills()
|
||||
_logger.info(
|
||||
f'Loaded {len(user_skills)} user skills: {[s.name for s in user_skills]}'
|
||||
)
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to load user skills: {str(e)}')
|
||||
user_skills = []
|
||||
|
||||
# Build org config (authentication handled by app-server)
|
||||
org_config = await build_org_config(selected_repository, self.user_context)
|
||||
# Load organization-level skills
|
||||
org_skills = await load_org_skills(
|
||||
remote_workspace, selected_repository, working_dir, self.user_context
|
||||
)
|
||||
|
||||
# Build sandbox config (exposed URLs)
|
||||
sandbox_config = build_sandbox_config(sandbox)
|
||||
repo_skills = await load_repo_skills(
|
||||
remote_workspace, selected_repository, working_dir
|
||||
)
|
||||
|
||||
# Determine project directory for project skills
|
||||
project_dir = working_dir
|
||||
if selected_repository:
|
||||
repo_name = selected_repository.split('/')[-1]
|
||||
project_dir = f'{working_dir}/{repo_name}'
|
||||
|
||||
# Single API call to agent-server for ALL skills
|
||||
all_skills = await load_skills_from_agent_server(
|
||||
agent_server_url=agent_server_url,
|
||||
session_api_key=sandbox.session_api_key,
|
||||
project_dir=project_dir,
|
||||
org_config=org_config,
|
||||
sandbox_config=sandbox_config,
|
||||
load_public=True,
|
||||
load_user=True,
|
||||
load_project=True,
|
||||
load_org=True,
|
||||
# Merge all skills (later lists override earlier ones)
|
||||
# Precedence: sandbox < global < user < org < repo
|
||||
all_skills = merge_skills(
|
||||
[sandbox_skills, global_skills, user_skills, org_skills, repo_skills]
|
||||
)
|
||||
|
||||
_logger.info(
|
||||
f'Loaded {len(all_skills)} total skills from agent-server: '
|
||||
f'{[s.name for s in all_skills]}'
|
||||
f'Loaded {len(all_skills)} total skills: {[s.name for s in all_skills]}'
|
||||
)
|
||||
|
||||
return all_skills
|
||||
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to load skills: {e}', exc_info=True)
|
||||
# Return empty list on failure - skills will be loaded again later if needed
|
||||
return []
|
||||
|
||||
def _create_agent_with_skills(self, agent, skills: list[Skill]):
|
||||
def _create_agent_with_skills(self, agent, skills: list):
|
||||
"""Create or update agent with skills in its context.
|
||||
|
||||
Args:
|
||||
@@ -137,9 +132,9 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
Updated agent with skills in context
|
||||
"""
|
||||
if agent.agent_context:
|
||||
# Merge with existing context (new skills override existing ones)
|
||||
# Merge with existing context
|
||||
existing_skills = agent.agent_context.skills
|
||||
all_skills = self._merge_skills([existing_skills, skills])
|
||||
all_skills = merge_skills([skills, existing_skills])
|
||||
agent = agent.model_copy(
|
||||
update={
|
||||
'agent_context': agent.agent_context.model_copy(
|
||||
@@ -154,25 +149,6 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
|
||||
return agent
|
||||
|
||||
def _merge_skills(self, skill_lists: list[list[Skill]]) -> list[Skill]:
|
||||
"""Merge multiple skill lists, avoiding duplicates by name.
|
||||
|
||||
Later lists take precedence over earlier lists for duplicate names.
|
||||
|
||||
Args:
|
||||
skill_lists: List of skill lists to merge
|
||||
|
||||
Returns:
|
||||
Deduplicated list of skills with later lists overriding earlier ones
|
||||
"""
|
||||
skills_by_name: dict[str, Skill] = {}
|
||||
|
||||
for skill_list in skill_lists:
|
||||
for skill in skill_list:
|
||||
skills_by_name[skill.name] = skill
|
||||
|
||||
return list(skills_by_name.values())
|
||||
|
||||
async def _load_skills_and_update_agent(
|
||||
self,
|
||||
sandbox: SandboxInfo,
|
||||
@@ -193,10 +169,8 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
Updated agent with skills loaded into context
|
||||
"""
|
||||
# Load and merge all skills
|
||||
# Extract agent_server_url from remote_workspace host
|
||||
agent_server_url = remote_workspace.host
|
||||
all_skills = await self.load_and_merge_all_skills(
|
||||
sandbox, selected_repository, working_dir, agent_server_url
|
||||
sandbox, remote_workspace, selected_repository, working_dir
|
||||
)
|
||||
|
||||
# Update agent with skills
|
||||
@@ -209,7 +183,6 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
task: AppConversationStartTask,
|
||||
sandbox: SandboxInfo,
|
||||
workspace: AsyncRemoteWorkspace,
|
||||
agent_server_url: str,
|
||||
) -> AsyncGenerator[AppConversationStartTask, None]:
|
||||
task.status = AppConversationStartTaskStatus.PREPARING_REPOSITORY
|
||||
yield task
|
||||
@@ -227,9 +200,9 @@ class AppConversationServiceBase(AppConversationService, ABC):
|
||||
yield task
|
||||
await self.load_and_merge_all_skills(
|
||||
sandbox,
|
||||
workspace,
|
||||
task.request.selected_repository,
|
||||
workspace.working_dir,
|
||||
agent_server_url,
|
||||
)
|
||||
|
||||
async def _configure_git_user_settings(
|
||||
|
||||
@@ -237,7 +237,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
working_dir=sandbox_spec.working_dir,
|
||||
)
|
||||
async for updated_task in self.run_setup_scripts(
|
||||
task, sandbox, remote_workspace, agent_server_url
|
||||
task, sandbox, remote_workspace
|
||||
):
|
||||
yield updated_task
|
||||
|
||||
@@ -1295,7 +1295,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
# Get all events for this conversation
|
||||
i = 0
|
||||
async for event in page_iterator(
|
||||
self.event_service.search_events, conversation_id=conversation_id
|
||||
self.event_service.search_events, conversation_id__eq=conversation_id
|
||||
):
|
||||
event_filename = f'event_{i:06d}_{event.id}.json'
|
||||
event_path = os.path.join(temp_dir, event_filename)
|
||||
|
||||
@@ -1,37 +1,34 @@
|
||||
"""Utilities for loading skills for V1 conversations.
|
||||
|
||||
This module provides functions to load skills from the agent-server,
|
||||
which centralizes all skill loading logic. The app-server acts as a
|
||||
thin proxy that:
|
||||
1. Builds the org_config with authentication information
|
||||
2. Builds the sandbox_config with exposed URLs
|
||||
3. Calls the agent-server's /api/skills endpoint
|
||||
This module provides functions to load skills from various sources:
|
||||
- Global skills from OpenHands/skills/
|
||||
- User skills from ~/.openhands/skills/
|
||||
- Repository-level skills from the workspace
|
||||
|
||||
All source-specific skill loading is handled by the agent-server.
|
||||
All skills are used in V1 conversations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
import openhands
|
||||
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.integrations.service_types import AuthenticationError
|
||||
from openhands.sdk.context.skills import Skill
|
||||
from openhands.sdk.context.skills.trigger import KeywordTrigger, TaskTrigger
|
||||
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExposedUrlConfig(BaseModel):
|
||||
"""Configuration for an exposed URL in sandbox config."""
|
||||
|
||||
name: str
|
||||
url: str
|
||||
port: int
|
||||
|
||||
# Path to global skills directory
|
||||
GLOBAL_SKILLS_DIR = os.path.join(
|
||||
os.path.dirname(os.path.dirname(openhands.__file__)),
|
||||
'skills',
|
||||
)
|
||||
WORK_HOSTS_SKILL = """The user has access to the following hosts for accessing a web application,
|
||||
each of which has a corresponding port:"""
|
||||
|
||||
WORK_HOSTS_SKILL_FOOTER = """
|
||||
When starting a web server, use the corresponding ports via environment variables:
|
||||
@@ -48,30 +45,96 @@ app.run(host='0.0.0.0', port=int(os.environ.get('WORKER_1', 12000)))
|
||||
```"""
|
||||
|
||||
|
||||
class SandboxConfig(BaseModel):
|
||||
"""Sandbox configuration for agent-server API request."""
|
||||
def _find_and_load_global_skill_files(skill_dir: Path) -> list[Skill]:
|
||||
"""Find and load all .md files from the global skills directory.
|
||||
|
||||
exposed_urls: list[ExposedUrlConfig]
|
||||
Args:
|
||||
skill_dir: Path to the global skills directory
|
||||
|
||||
Returns:
|
||||
List of Skill objects loaded from the files (excluding README.md)
|
||||
"""
|
||||
skills = []
|
||||
|
||||
try:
|
||||
# Find all .md files in the directory (excluding README.md)
|
||||
md_files = [f for f in skill_dir.glob('*.md') if f.name.lower() != 'readme.md']
|
||||
|
||||
# Load skills from the found files
|
||||
for file_path in md_files:
|
||||
try:
|
||||
skill = Skill.load(file_path, skill_dir)
|
||||
skills.append(skill)
|
||||
_logger.debug(f'Loaded global skill: {skill.name} from {file_path}')
|
||||
except Exception as e:
|
||||
_logger.warning(
|
||||
f'Failed to load global skill from {file_path}: {str(e)}'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
_logger.debug(f'Failed to find global skill files: {str(e)}')
|
||||
|
||||
return skills
|
||||
|
||||
|
||||
class OrgConfig(BaseModel):
|
||||
"""Organization configuration for agent-server API request."""
|
||||
|
||||
repository: str
|
||||
provider: str
|
||||
org_repo_url: str
|
||||
org_name: str
|
||||
def load_sandbox_skills(sandbox: SandboxInfo) -> list[Skill]:
|
||||
"""Load skills specific to the sandbox, including exposed ports / urls."""
|
||||
if not sandbox.exposed_urls:
|
||||
return []
|
||||
urls = [url for url in sandbox.exposed_urls if url.name.startswith('WORKER_')]
|
||||
if not urls:
|
||||
return []
|
||||
content_list = [WORK_HOSTS_SKILL]
|
||||
for url in urls:
|
||||
content_list.append(f'* {url.url} (port {url.port})')
|
||||
content_list.append(WORK_HOSTS_SKILL_FOOTER)
|
||||
content = '\n'.join(content_list)
|
||||
return [Skill(name='work_hosts', content=content, trigger=None)]
|
||||
|
||||
|
||||
class SkillInfo(BaseModel):
|
||||
"""Skill information from agent-server API response."""
|
||||
def load_global_skills() -> list[Skill]:
|
||||
"""Load global skills from OpenHands/skills/ directory.
|
||||
|
||||
name: str
|
||||
content: str
|
||||
triggers: list[str] = []
|
||||
source: str | None = None
|
||||
description: str | None = None
|
||||
is_agentskills_format: bool = False
|
||||
Returns:
|
||||
List of Skill objects loaded from global skills directory.
|
||||
Returns empty list if directory doesn't exist or on errors.
|
||||
"""
|
||||
skill_dir = Path(GLOBAL_SKILLS_DIR)
|
||||
|
||||
# Check if directory exists
|
||||
if not skill_dir.exists():
|
||||
_logger.debug(f'Global skills directory does not exist: {skill_dir}')
|
||||
return []
|
||||
|
||||
try:
|
||||
_logger.info(f'Loading global skills from {skill_dir}')
|
||||
|
||||
# Find and load all .md files from the directory
|
||||
skills = _find_and_load_global_skill_files(skill_dir)
|
||||
|
||||
_logger.info(f'Loaded {len(skills)} global skills: {[s.name for s in skills]}')
|
||||
|
||||
return skills
|
||||
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to load global skills: {str(e)}')
|
||||
return []
|
||||
|
||||
|
||||
def _determine_repo_root(working_dir: str, selected_repository: str | None) -> str:
|
||||
"""Determine the repository root directory.
|
||||
|
||||
Args:
|
||||
working_dir: Base working directory path
|
||||
selected_repository: Repository name (e.g., 'owner/repo') or None
|
||||
|
||||
Returns:
|
||||
Path to the repository root directory
|
||||
"""
|
||||
if selected_repository:
|
||||
repo_name = selected_repository.split('/')[-1]
|
||||
return f'{working_dir}/{repo_name}'
|
||||
return working_dir
|
||||
|
||||
|
||||
async def _is_gitlab_repository(repo_name: str, user_context: UserContext) -> bool:
|
||||
@@ -91,6 +154,8 @@ async def _is_gitlab_repository(repo_name: str, user_context: UserContext) -> bo
|
||||
)
|
||||
return repository.git_provider == ProviderType.GITLAB
|
||||
except Exception:
|
||||
# If we can't determine the provider, assume it's not GitLab
|
||||
# This is a safe fallback since we'll just use the default .openhands
|
||||
return False
|
||||
|
||||
|
||||
@@ -113,33 +178,10 @@ async def _is_azure_devops_repository(
|
||||
)
|
||||
return repository.git_provider == ProviderType.AZURE_DEVOPS
|
||||
except Exception:
|
||||
# If we can't determine the provider, assume it's not Azure DevOps
|
||||
return False
|
||||
|
||||
|
||||
async def _get_provider_type(
|
||||
selected_repository: str, user_context: UserContext
|
||||
) -> str:
|
||||
"""Determine the Git provider type for a repository.
|
||||
|
||||
Args:
|
||||
selected_repository: Repository name (e.g., 'owner/repo')
|
||||
user_context: UserContext to access provider handler
|
||||
|
||||
Returns:
|
||||
Provider type string: 'github', 'gitlab', 'azure', or 'bitbucket'
|
||||
"""
|
||||
is_gitlab = await _is_gitlab_repository(selected_repository, user_context)
|
||||
if is_gitlab:
|
||||
return 'gitlab'
|
||||
|
||||
is_azure = await _is_azure_devops_repository(selected_repository, user_context)
|
||||
if is_azure:
|
||||
return 'azure'
|
||||
|
||||
# Default to github (covers github and bitbucket)
|
||||
return 'github'
|
||||
|
||||
|
||||
async def _determine_org_repo_path(
|
||||
selected_repository: str, user_context: UserContext
|
||||
) -> tuple[str, str]:
|
||||
@@ -161,19 +203,27 @@ async def _determine_org_repo_path(
|
||||
"""
|
||||
repo_parts = selected_repository.split('/')
|
||||
|
||||
# Determine repository type
|
||||
is_azure_devops = await _is_azure_devops_repository(
|
||||
selected_repository, user_context
|
||||
)
|
||||
is_gitlab = await _is_gitlab_repository(selected_repository, user_context)
|
||||
|
||||
# Extract the org/user name
|
||||
# Azure DevOps format: org/project/repo (3 parts) - extract org (first part)
|
||||
# GitHub/GitLab/Bitbucket format: owner/repo (2 parts) - extract owner (first part)
|
||||
if is_azure_devops and len(repo_parts) >= 3:
|
||||
org_name = repo_parts[0]
|
||||
org_name = repo_parts[0] # Get org from org/project/repo
|
||||
else:
|
||||
org_name = repo_parts[-2]
|
||||
org_name = repo_parts[-2] # Get owner from owner/repo
|
||||
|
||||
# For GitLab and Azure DevOps, use openhands-config (since .openhands is not a valid repo name)
|
||||
# For other providers, use .openhands
|
||||
if is_gitlab:
|
||||
org_openhands_repo = f'{org_name}/openhands-config'
|
||||
elif is_azure_devops:
|
||||
# Azure DevOps format: org/project/repo
|
||||
# For org-level config, use: org/openhands-config/openhands-config
|
||||
org_openhands_repo = f'{org_name}/openhands-config/openhands-config'
|
||||
else:
|
||||
org_openhands_repo = f'{org_name}/.openhands'
|
||||
@@ -181,6 +231,227 @@ async def _determine_org_repo_path(
|
||||
return org_openhands_repo, org_name
|
||||
|
||||
|
||||
async def _read_file_from_workspace(
|
||||
workspace: AsyncRemoteWorkspace, file_path: str, working_dir: str
|
||||
) -> str | None:
|
||||
"""Read file content from remote workspace.
|
||||
|
||||
Args:
|
||||
workspace: AsyncRemoteWorkspace to execute commands
|
||||
file_path: Path to the file to read
|
||||
working_dir: Working directory for command execution
|
||||
|
||||
Returns:
|
||||
File content as string, or None if file doesn't exist or read fails
|
||||
"""
|
||||
try:
|
||||
result = await workspace.execute_command(
|
||||
f'cat {file_path}', cwd=working_dir, timeout=10.0
|
||||
)
|
||||
if result.exit_code == 0 and result.stdout.strip():
|
||||
return result.stdout
|
||||
return None
|
||||
except Exception as e:
|
||||
_logger.debug(f'Failed to read file {file_path}: {str(e)}')
|
||||
return None
|
||||
|
||||
|
||||
async def _load_special_files(
|
||||
workspace: AsyncRemoteWorkspace, repo_root: str, working_dir: str
|
||||
) -> list[Skill]:
|
||||
"""Load special skill files from repository root.
|
||||
|
||||
Loads: .cursorrules, agents.md, agent.md
|
||||
|
||||
Args:
|
||||
workspace: AsyncRemoteWorkspace to execute commands
|
||||
repo_root: Path to repository root directory
|
||||
working_dir: Working directory for command execution
|
||||
|
||||
Returns:
|
||||
List of Skill objects loaded from special files
|
||||
"""
|
||||
skills = []
|
||||
special_files = ['.cursorrules', 'agents.md', 'agent.md']
|
||||
|
||||
for filename in special_files:
|
||||
file_path = f'{repo_root}/{filename}'
|
||||
content = await _read_file_from_workspace(workspace, file_path, working_dir)
|
||||
|
||||
if content:
|
||||
try:
|
||||
# Use simple string path to avoid Path filesystem operations
|
||||
skill = Skill.load(path=filename, skill_dir=None, file_content=content)
|
||||
skills.append(skill)
|
||||
_logger.debug(f'Loaded special file skill: {skill.name}')
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to create skill from {filename}: {str(e)}')
|
||||
|
||||
return skills
|
||||
|
||||
|
||||
async def _find_and_load_skill_md_files(
|
||||
workspace: AsyncRemoteWorkspace, skill_dir: str, working_dir: str
|
||||
) -> list[Skill]:
|
||||
"""Find and load all .md files from a skills directory in the workspace.
|
||||
|
||||
Args:
|
||||
workspace: AsyncRemoteWorkspace to execute commands
|
||||
skill_dir: Path to skills directory
|
||||
working_dir: Working directory for command execution
|
||||
|
||||
Returns:
|
||||
List of Skill objects loaded from the files (excluding README.md)
|
||||
"""
|
||||
skills = []
|
||||
|
||||
try:
|
||||
# Find all .md files in the directory
|
||||
result = await workspace.execute_command(
|
||||
f"find {skill_dir} -type f -name '*.md' 2>/dev/null || true",
|
||||
cwd=working_dir,
|
||||
timeout=10.0,
|
||||
)
|
||||
|
||||
if result.exit_code == 0 and result.stdout.strip():
|
||||
file_paths = [
|
||||
f.strip()
|
||||
for f in result.stdout.strip().split('\n')
|
||||
if f.strip() and 'README.md' not in f
|
||||
]
|
||||
|
||||
# Load skills from the found files
|
||||
for file_path in file_paths:
|
||||
content = await _read_file_from_workspace(
|
||||
workspace, file_path, working_dir
|
||||
)
|
||||
|
||||
if content:
|
||||
# Calculate relative path for skill name
|
||||
rel_path = file_path.replace(f'{skill_dir}/', '')
|
||||
try:
|
||||
# Use simple string path to avoid Path filesystem operations
|
||||
skill = Skill.load(
|
||||
path=rel_path, skill_dir=None, file_content=content
|
||||
)
|
||||
skills.append(skill)
|
||||
_logger.debug(f'Loaded repo skill: {skill.name}')
|
||||
except Exception as e:
|
||||
_logger.warning(
|
||||
f'Failed to create skill from {rel_path}: {str(e)}'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
_logger.debug(f'Failed to find skill files in {skill_dir}: {str(e)}')
|
||||
|
||||
return skills
|
||||
|
||||
|
||||
def _merge_repo_skills_with_precedence(
|
||||
special_skills: list[Skill],
|
||||
skills_dir_skills: list[Skill],
|
||||
microagents_dir_skills: list[Skill],
|
||||
) -> list[Skill]:
|
||||
"""Merge repository skills with precedence order.
|
||||
|
||||
Precedence (highest to lowest):
|
||||
1. Special files (repo root)
|
||||
2. .openhands/skills/ directory
|
||||
3. .openhands/microagents/ directory (backward compatibility)
|
||||
|
||||
Args:
|
||||
special_skills: Skills from special files in repo root
|
||||
skills_dir_skills: Skills from .openhands/skills/ directory
|
||||
microagents_dir_skills: Skills from .openhands/microagents/ directory
|
||||
|
||||
Returns:
|
||||
Deduplicated list of skills with proper precedence
|
||||
"""
|
||||
# Use a dict to deduplicate by name, with earlier sources taking precedence
|
||||
skills_by_name = {}
|
||||
for skill in special_skills + skills_dir_skills + microagents_dir_skills:
|
||||
# Only add if not already present (earlier sources win)
|
||||
if skill.name not in skills_by_name:
|
||||
skills_by_name[skill.name] = skill
|
||||
|
||||
return list(skills_by_name.values())
|
||||
|
||||
|
||||
async def load_repo_skills(
|
||||
workspace: AsyncRemoteWorkspace,
|
||||
selected_repository: str | None,
|
||||
working_dir: str,
|
||||
) -> list[Skill]:
|
||||
"""Load repository-level skills from the workspace.
|
||||
|
||||
Loads skills from:
|
||||
1. Special files in repo root: .cursorrules, agents.md, agent.md
|
||||
2. .md files in .openhands/skills/ directory (preferred)
|
||||
3. .md files in .openhands/microagents/ directory (for backward compatibility)
|
||||
|
||||
Args:
|
||||
workspace: AsyncRemoteWorkspace to execute commands in the sandbox
|
||||
selected_repository: Repository name (e.g., 'owner/repo') or None
|
||||
working_dir: Working directory path
|
||||
|
||||
Returns:
|
||||
List of Skill objects loaded from repository.
|
||||
Returns empty list on errors.
|
||||
"""
|
||||
try:
|
||||
# Determine repository root directory
|
||||
repo_root = _determine_repo_root(working_dir, selected_repository)
|
||||
_logger.info(f'Loading repo skills from {repo_root}')
|
||||
|
||||
# Load special files from repo root
|
||||
special_skills = await _load_special_files(workspace, repo_root, working_dir)
|
||||
|
||||
# Load .md files from .openhands/skills/ directory (preferred)
|
||||
skills_dir = f'{repo_root}/.openhands/skills'
|
||||
skills_dir_skills = await _find_and_load_skill_md_files(
|
||||
workspace, skills_dir, working_dir
|
||||
)
|
||||
|
||||
# Load .md files from .openhands/microagents/ directory (backward compatibility)
|
||||
microagents_dir = f'{repo_root}/.openhands/microagents'
|
||||
microagents_dir_skills = await _find_and_load_skill_md_files(
|
||||
workspace, microagents_dir, working_dir
|
||||
)
|
||||
|
||||
# Merge all loaded skills with proper precedence
|
||||
all_skills = _merge_repo_skills_with_precedence(
|
||||
special_skills, skills_dir_skills, microagents_dir_skills
|
||||
)
|
||||
|
||||
_logger.info(
|
||||
f'Loaded {len(all_skills)} repo skills: {[s.name for s in all_skills]}'
|
||||
)
|
||||
|
||||
return all_skills
|
||||
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to load repo skills: {str(e)}')
|
||||
return []
|
||||
|
||||
|
||||
def _validate_repository_for_org_skills(selected_repository: str) -> bool:
|
||||
"""Validate that the repository path has sufficient parts for org skills.
|
||||
|
||||
Args:
|
||||
selected_repository: Repository name (e.g., 'owner/repo')
|
||||
|
||||
Returns:
|
||||
True if repository is valid for org skills loading, False otherwise
|
||||
"""
|
||||
repo_parts = selected_repository.split('/')
|
||||
if len(repo_parts) < 2:
|
||||
_logger.warning(
|
||||
f'Repository path has insufficient parts ({len(repo_parts)} < 2), skipping org-level skills'
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def _get_org_repository_url(
|
||||
org_openhands_repo: str, user_context: UserContext
|
||||
) -> str | None:
|
||||
@@ -210,193 +481,224 @@ async def _get_org_repository_url(
|
||||
return None
|
||||
|
||||
|
||||
async def build_org_config(
|
||||
selected_repository: str | None,
|
||||
user_context: UserContext,
|
||||
) -> OrgConfig | None:
|
||||
"""Build organization config for agent-server API request.
|
||||
async def _clone_org_repository(
|
||||
workspace: AsyncRemoteWorkspace,
|
||||
remote_url: str,
|
||||
org_repo_dir: str,
|
||||
working_dir: str,
|
||||
org_openhands_repo: str,
|
||||
) -> bool:
|
||||
"""Clone organization repository to temporary directory.
|
||||
|
||||
Args:
|
||||
selected_repository: Repository name (e.g., 'owner/repo') or None
|
||||
user_context: UserContext to access authentication and provider info
|
||||
workspace: AsyncRemoteWorkspace to execute commands
|
||||
remote_url: Authenticated Git URL
|
||||
org_repo_dir: Temporary directory path for cloning
|
||||
working_dir: Working directory for command execution
|
||||
org_openhands_repo: Organization repository path (for logging)
|
||||
|
||||
Returns:
|
||||
org_config dict if org repository exists and is accessible, None otherwise
|
||||
True if clone successful, False otherwise
|
||||
"""
|
||||
_logger.debug(f'Creating temporary directory for org repo: {org_repo_dir}')
|
||||
|
||||
# Clone the repo (shallow clone for efficiency)
|
||||
clone_cmd = f'GIT_TERMINAL_PROMPT=0 git clone --depth 1 {remote_url} {org_repo_dir}'
|
||||
_logger.info('Executing clone command for org-level repo')
|
||||
|
||||
result = await workspace.execute_command(clone_cmd, working_dir, timeout=120.0)
|
||||
|
||||
if result.exit_code != 0:
|
||||
_logger.info(
|
||||
f'No org-level skills found at {org_openhands_repo} (exit_code: {result.exit_code})'
|
||||
)
|
||||
_logger.debug(f'Clone command output: {result.stderr}')
|
||||
return False
|
||||
|
||||
_logger.info(f'Successfully cloned org-level skills from {org_openhands_repo}')
|
||||
return True
|
||||
|
||||
|
||||
async def _load_skills_from_org_directories(
|
||||
workspace: AsyncRemoteWorkspace, org_repo_dir: str, working_dir: str
|
||||
) -> tuple[list[Skill], list[Skill]]:
|
||||
"""Load skills from both skills/ and microagents/ directories in org repo.
|
||||
|
||||
Args:
|
||||
workspace: AsyncRemoteWorkspace to execute commands
|
||||
org_repo_dir: Path to cloned organization repository
|
||||
working_dir: Working directory for command execution
|
||||
|
||||
Returns:
|
||||
Tuple of (skills_dir_skills, microagents_dir_skills)
|
||||
"""
|
||||
skills_dir = f'{org_repo_dir}/skills'
|
||||
skills_dir_skills = await _find_and_load_skill_md_files(
|
||||
workspace, skills_dir, working_dir
|
||||
)
|
||||
|
||||
microagents_dir = f'{org_repo_dir}/microagents'
|
||||
microagents_dir_skills = await _find_and_load_skill_md_files(
|
||||
workspace, microagents_dir, working_dir
|
||||
)
|
||||
|
||||
return skills_dir_skills, microagents_dir_skills
|
||||
|
||||
|
||||
def _merge_org_skills_with_precedence(
|
||||
skills_dir_skills: list[Skill], microagents_dir_skills: list[Skill]
|
||||
) -> list[Skill]:
|
||||
"""Merge skills from skills/ and microagents/ with proper precedence.
|
||||
|
||||
Precedence: skills/ > microagents/ (skills/ overrides microagents/ for same name)
|
||||
|
||||
Args:
|
||||
skills_dir_skills: Skills loaded from skills/ directory
|
||||
microagents_dir_skills: Skills loaded from microagents/ directory
|
||||
|
||||
Returns:
|
||||
Merged list of skills with proper precedence applied
|
||||
"""
|
||||
skills_by_name = {}
|
||||
for skill in microagents_dir_skills + skills_dir_skills:
|
||||
# Later sources (skills/) override earlier ones (microagents/)
|
||||
if skill.name not in skills_by_name:
|
||||
skills_by_name[skill.name] = skill
|
||||
else:
|
||||
_logger.debug(
|
||||
f'Overriding org skill "{skill.name}" from microagents/ with skills/'
|
||||
)
|
||||
skills_by_name[skill.name] = skill
|
||||
|
||||
return list(skills_by_name.values())
|
||||
|
||||
|
||||
async def _cleanup_org_repository(
|
||||
workspace: AsyncRemoteWorkspace, org_repo_dir: str, working_dir: str
|
||||
) -> None:
|
||||
"""Clean up cloned organization repository directory.
|
||||
|
||||
Args:
|
||||
workspace: AsyncRemoteWorkspace to execute commands
|
||||
org_repo_dir: Path to cloned organization repository
|
||||
working_dir: Working directory for command execution
|
||||
"""
|
||||
cleanup_cmd = f'rm -rf {org_repo_dir}'
|
||||
await workspace.execute_command(cleanup_cmd, working_dir, timeout=10.0)
|
||||
|
||||
|
||||
async def load_org_skills(
|
||||
workspace: AsyncRemoteWorkspace,
|
||||
selected_repository: str | None,
|
||||
working_dir: str,
|
||||
user_context: UserContext,
|
||||
) -> list[Skill]:
|
||||
"""Load organization-level skills from the organization repository.
|
||||
|
||||
For example, if the repository is github.com/acme-co/api, this will check if
|
||||
github.com/acme-co/.openhands exists. If it does, it will clone it and load
|
||||
the skills from both the ./skills/ and ./microagents/ folders.
|
||||
|
||||
For GitLab repositories, it will use openhands-config instead of .openhands
|
||||
since GitLab doesn't support repository names starting with non-alphanumeric
|
||||
characters.
|
||||
|
||||
For Azure DevOps repositories, it will use org/openhands-config/openhands-config
|
||||
format to match Azure DevOps's three-part repository structure (org/project/repo).
|
||||
|
||||
Args:
|
||||
workspace: AsyncRemoteWorkspace to execute commands in the sandbox
|
||||
selected_repository: Repository name (e.g., 'owner/repo') or None
|
||||
working_dir: Working directory path
|
||||
user_context: UserContext to access provider handler and authentication
|
||||
|
||||
Returns:
|
||||
List of Skill objects loaded from organization repository.
|
||||
Returns empty list if no repository selected or on errors.
|
||||
"""
|
||||
if not selected_repository:
|
||||
return None
|
||||
|
||||
repo_parts = selected_repository.split('/')
|
||||
if len(repo_parts) < 2:
|
||||
_logger.warning(
|
||||
f'Repository path has insufficient parts ({len(repo_parts)} < 2), '
|
||||
f'skipping org-level skills'
|
||||
)
|
||||
return None
|
||||
return []
|
||||
|
||||
try:
|
||||
_logger.debug(
|
||||
f'Starting org-level skill loading for repository: {selected_repository}'
|
||||
)
|
||||
|
||||
# Validate repository path
|
||||
if not _validate_repository_for_org_skills(selected_repository):
|
||||
return []
|
||||
|
||||
# Determine organization repository path
|
||||
org_openhands_repo, org_name = await _determine_org_repo_path(
|
||||
selected_repository, user_context
|
||||
)
|
||||
|
||||
org_repo_url = await _get_org_repository_url(org_openhands_repo, user_context)
|
||||
if not org_repo_url:
|
||||
return None
|
||||
_logger.info(f'Checking for org-level skills at {org_openhands_repo}')
|
||||
|
||||
provider = await _get_provider_type(selected_repository, user_context)
|
||||
# Get authenticated URL for org repository
|
||||
remote_url = await _get_org_repository_url(org_openhands_repo, user_context)
|
||||
if not remote_url:
|
||||
return []
|
||||
|
||||
return OrgConfig(
|
||||
repository=selected_repository,
|
||||
provider=provider,
|
||||
org_repo_url=org_repo_url,
|
||||
org_name=org_name,
|
||||
# Clone the organization repository
|
||||
org_repo_dir = f'{working_dir}/_org_openhands_{org_name}'
|
||||
clone_success = await _clone_org_repository(
|
||||
workspace, remote_url, org_repo_dir, working_dir, org_openhands_repo
|
||||
)
|
||||
if not clone_success:
|
||||
return []
|
||||
|
||||
# Load skills from both skills/ and microagents/ directories
|
||||
(
|
||||
skills_dir_skills,
|
||||
microagents_dir_skills,
|
||||
) = await _load_skills_from_org_directories(
|
||||
workspace, org_repo_dir, working_dir
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
_logger.debug(f'Failed to build org config: {str(e)}')
|
||||
return None
|
||||
# Merge skills with proper precedence
|
||||
loaded_skills = _merge_org_skills_with_precedence(
|
||||
skills_dir_skills, microagents_dir_skills
|
||||
)
|
||||
|
||||
|
||||
def build_sandbox_config(sandbox: SandboxInfo) -> SandboxConfig | None:
|
||||
"""Build sandbox config for agent-server API request.
|
||||
|
||||
Args:
|
||||
sandbox: SandboxInfo containing exposed URLs
|
||||
|
||||
Returns:
|
||||
sandbox_config dict if there are exposed URLs, None otherwise
|
||||
"""
|
||||
if not sandbox.exposed_urls:
|
||||
return None
|
||||
|
||||
exposed_urls = [
|
||||
ExposedUrlConfig(name=url.name, url=url.url, port=url.port)
|
||||
for url in sandbox.exposed_urls
|
||||
]
|
||||
|
||||
return SandboxConfig(exposed_urls=exposed_urls)
|
||||
|
||||
|
||||
async def load_skills_from_agent_server(
|
||||
agent_server_url: str,
|
||||
session_api_key: str | None,
|
||||
project_dir: str,
|
||||
org_config: OrgConfig | None = None,
|
||||
sandbox_config: SandboxConfig | None = None,
|
||||
load_public: bool = True,
|
||||
load_user: bool = True,
|
||||
load_project: bool = True,
|
||||
load_org: bool = True,
|
||||
) -> list[Skill]:
|
||||
"""Load all skills from the agent-server.
|
||||
|
||||
This function makes a single API call to the agent-server's /api/skills
|
||||
endpoint to load and merge skills from all configured sources.
|
||||
|
||||
Args:
|
||||
agent_server_url: URL of the agent server (e.g., 'http://localhost:8000')
|
||||
session_api_key: Session API key for authentication (optional)
|
||||
project_dir: Workspace directory path for project skills
|
||||
org_config: Organization skills configuration (optional)
|
||||
sandbox_config: Sandbox skills configuration (optional)
|
||||
load_public: Whether to load public skills (default: True)
|
||||
load_user: Whether to load user skills (default: True)
|
||||
load_project: Whether to load project skills (default: True)
|
||||
load_org: Whether to load organization skills (default: True)
|
||||
|
||||
Returns:
|
||||
List of Skill objects merged from all sources.
|
||||
Returns empty list on error.
|
||||
"""
|
||||
try:
|
||||
# Build request payload
|
||||
payload = {
|
||||
'load_public': load_public,
|
||||
'load_user': load_user,
|
||||
'load_project': load_project,
|
||||
'load_org': load_org,
|
||||
'project_dir': project_dir,
|
||||
'org_config': org_config.model_dump() if org_config else None,
|
||||
'sandbox_config': sandbox_config.model_dump() if sandbox_config else None,
|
||||
}
|
||||
|
||||
# Build headers
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
if session_api_key:
|
||||
headers['X-Session-API-Key'] = session_api_key
|
||||
|
||||
# Make API request
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f'{agent_server_url}/api/skills',
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=60.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Convert response to Skill objects
|
||||
skills: list[Skill] = []
|
||||
for skill_data_dict in data.get('skills', []):
|
||||
try:
|
||||
skill_info = SkillInfo.model_validate(skill_data_dict)
|
||||
skill = _convert_skill_info_to_skill(skill_info)
|
||||
skills.append(skill)
|
||||
except Exception as e:
|
||||
skill_name = (
|
||||
skill_data_dict.get('name', 'unknown')
|
||||
if isinstance(skill_data_dict, dict)
|
||||
else 'unknown'
|
||||
)
|
||||
_logger.warning(f'Failed to convert skill {skill_name}: {e}')
|
||||
|
||||
sources = data.get('sources', {})
|
||||
_logger.info(
|
||||
f'Loaded {len(skills)} skills from agent-server: '
|
||||
f'sources={sources}, names={[s.name for s in skills]}'
|
||||
f'Loaded {len(loaded_skills)} skills from org-level repository {org_openhands_repo}: {[s.name for s in loaded_skills]}'
|
||||
)
|
||||
|
||||
return skills
|
||||
# Clean up the org repo directory
|
||||
await _cleanup_org_repository(workspace, org_repo_dir, working_dir)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
_logger.warning(
|
||||
f'Agent-server returned error status {e.response.status_code}: '
|
||||
f'{e.response.text}'
|
||||
)
|
||||
return []
|
||||
except httpx.RequestError as e:
|
||||
_logger.warning(f'Failed to connect to agent-server: {e}')
|
||||
return loaded_skills
|
||||
|
||||
except AuthenticationError as e:
|
||||
_logger.debug(f'org-level skill directory not found: {str(e)}')
|
||||
return []
|
||||
except Exception as e:
|
||||
_logger.warning(f'Failed to load skills from agent-server: {e}')
|
||||
_logger.warning(f'Failed to load org-level skills: {str(e)}')
|
||||
return []
|
||||
|
||||
|
||||
def _convert_skill_info_to_skill(skill_info: SkillInfo) -> Skill:
|
||||
"""Convert skill info from API response to Skill object.
|
||||
def merge_skills(skill_lists: list[list[Skill]]) -> list[Skill]:
|
||||
"""Merge multiple skill lists, avoiding duplicates by name.
|
||||
|
||||
Later lists take precedence over earlier lists for duplicate names.
|
||||
|
||||
Args:
|
||||
skill_info: SkillInfo model from API response
|
||||
skill_lists: List of skill lists to merge
|
||||
|
||||
Returns:
|
||||
Skill object
|
||||
Deduplicated list of skills with later lists overriding earlier ones
|
||||
"""
|
||||
trigger = None
|
||||
skills_by_name = {}
|
||||
|
||||
if skill_info.triggers:
|
||||
# Determine trigger type based on content
|
||||
if any(t.startswith('/') for t in skill_info.triggers):
|
||||
trigger = TaskTrigger(triggers=skill_info.triggers)
|
||||
else:
|
||||
trigger = KeywordTrigger(keywords=skill_info.triggers)
|
||||
for skill_list in skill_lists:
|
||||
for skill in skill_list:
|
||||
if skill.name in skills_by_name:
|
||||
_logger.debug(
|
||||
f'Overriding skill "{skill.name}" from earlier source with later source'
|
||||
)
|
||||
skills_by_name[skill.name] = skill
|
||||
|
||||
return Skill(
|
||||
name=skill_info.name,
|
||||
content=skill_info.content,
|
||||
trigger=trigger,
|
||||
source=skill_info.source,
|
||||
description=skill_info.description,
|
||||
is_agentskills_format=skill_info.is_agentskills_format,
|
||||
)
|
||||
result = list(skills_by_name.values())
|
||||
_logger.debug(f'Merged skills: {[s.name for s in result]}')
|
||||
return result
|
||||
|
||||
@@ -13,7 +13,7 @@ from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
# The version of the agent server to use for deployments.
|
||||
# Typically this will be the same as the values from the pyproject.toml
|
||||
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:31536c8-python'
|
||||
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:0fdea73-python'
|
||||
|
||||
|
||||
class SandboxSpecService(ABC):
|
||||
|
||||
12
openhands/architecture/README.md
Normal file
12
openhands/architecture/README.md
Normal file
@@ -0,0 +1,12 @@
|
||||
# OpenHands Architecture
|
||||
|
||||
This document provides detailed architecture diagrams and explanations for the OpenHands system.
|
||||
|
||||
## Documentation Sections
|
||||
|
||||
- [System Architecture Overview](./system-architecture.md)
|
||||
- [Conversation Startup & WebSocket Flow](./conversation-startup.md)
|
||||
- [Authentication Flow](./authentication.md)
|
||||
- [Agent Execution & LLM Flow](./agent-execution.md)
|
||||
- [External Integrations](./external-integrations.md)
|
||||
- [Metrics, Logs & Observability](./observability.md)
|
||||
96
openhands/architecture/agent-execution.md
Normal file
96
openhands/architecture/agent-execution.md
Normal file
@@ -0,0 +1,96 @@
|
||||
# Agent Execution & LLM Flow
|
||||
|
||||
When the agent executes inside the sandbox, it makes LLM calls through LiteLLM:
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
participant User as User (Browser)
|
||||
participant AS as Agent Server
|
||||
participant Agent as Agent<br/>(CodeAct)
|
||||
participant LLM as LLM Class
|
||||
participant Lite as LiteLLM
|
||||
participant Proxy as LLM Proxy<br/>(llm-proxy.app.all-hands.dev)
|
||||
participant Provider as LLM Provider<br/>(OpenAI, Anthropic, etc.)
|
||||
participant AES as Action Execution Server
|
||||
|
||||
Note over User,AES: Agent Loop - LLM Call Flow
|
||||
|
||||
User->>AS: WebSocket: User message
|
||||
AS->>Agent: Process message
|
||||
Agent->>Agent: Build prompt from state
|
||||
|
||||
Agent->>LLM: completion(messages, tools)
|
||||
LLM->>LLM: Apply config (model, temp, etc.)
|
||||
|
||||
alt Using OpenHands Provider
|
||||
LLM->>Lite: litellm_proxy/{model}
|
||||
Lite->>Proxy: POST /chat/completions
|
||||
Proxy->>Proxy: Auth, rate limit, routing
|
||||
Proxy->>Provider: Forward request
|
||||
Provider-->>Proxy: Response
|
||||
Proxy-->>Lite: Response
|
||||
else Using Direct Provider
|
||||
LLM->>Lite: {provider}/{model}
|
||||
Lite->>Provider: Direct API call
|
||||
Provider-->>Lite: Response
|
||||
end
|
||||
|
||||
Lite-->>LLM: ModelResponse
|
||||
LLM->>LLM: Track metrics (cost, tokens)
|
||||
LLM-->>Agent: Parsed response
|
||||
|
||||
Agent->>Agent: Parse action from response
|
||||
AS->>User: WebSocket: Action event
|
||||
|
||||
Note over User,AES: Action Execution
|
||||
|
||||
AS->>AES: HTTP: Execute action
|
||||
AES->>AES: Run command/edit file
|
||||
AES-->>AS: Observation
|
||||
AS->>User: WebSocket: Observation event
|
||||
|
||||
Agent->>Agent: Update state
|
||||
Note over Agent: Loop continues...
|
||||
```
|
||||
|
||||
### LLM Components
|
||||
|
||||
| Component | Purpose | Location |
|
||||
|-----------|---------|----------|
|
||||
| **LLM Class** | Wrapper with retries, metrics, config | `openhands/llm/llm.py` |
|
||||
| **LiteLLM** | Universal LLM API adapter | External library |
|
||||
| **LLM Proxy** | OpenHands managed proxy for billing/routing | `llm-proxy.app.all-hands.dev` |
|
||||
| **LLM Registry** | Manages multiple LLM instances | `openhands/llm/llm_registry.py` |
|
||||
|
||||
### Model Routing
|
||||
|
||||
```
|
||||
User selects model
|
||||
│
|
||||
▼
|
||||
┌───────────────────┐
|
||||
│ Model prefix? │
|
||||
└───────────────────┘
|
||||
│
|
||||
├── openhands/claude-3-5 ──► Rewrite to litellm_proxy/claude-3-5
|
||||
│ Base URL: llm-proxy.app.all-hands.dev
|
||||
│
|
||||
├── anthropic/claude-3-5 ──► Direct to Anthropic API
|
||||
│ (User's API key)
|
||||
│
|
||||
├── openai/gpt-4 ──► Direct to OpenAI API
|
||||
│ (User's API key)
|
||||
│
|
||||
└── azure/gpt-4 ──► Direct to Azure OpenAI
|
||||
(User's API key + endpoint)
|
||||
```
|
||||
|
||||
### LLM Proxy Benefits
|
||||
|
||||
When using `openhands/` prefixed models:
|
||||
- **Unified Billing**: Costs tracked through OpenHands account
|
||||
- **No API Keys Needed**: Users don't need their own provider keys
|
||||
- **Rate Limiting**: Managed quotas and throttling
|
||||
- **Model Routing**: Automatic failover and load balancing
|
||||
- **Usage Tracking**: Detailed metrics and cost analysis
|
||||
58
openhands/architecture/authentication.md
Normal file
58
openhands/architecture/authentication.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# Authentication Flow
|
||||
|
||||
OpenHands uses Keycloak for identity management in the SaaS deployment. The authentication flow involves multiple services:
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
participant User as User (Browser)
|
||||
participant App as App Server
|
||||
participant KC as Keycloak
|
||||
participant IdP as Identity Provider<br/>(GitHub, Google, etc.)
|
||||
participant DB as User Database
|
||||
|
||||
Note over User,DB: OAuth 2.0 / OIDC Authentication Flow
|
||||
|
||||
User->>App: Access OpenHands
|
||||
App->>User: Redirect to Keycloak
|
||||
User->>KC: Login request
|
||||
KC->>User: Show login options
|
||||
User->>KC: Select provider (e.g., GitHub)
|
||||
KC->>IdP: OAuth redirect
|
||||
User->>IdP: Authenticate
|
||||
IdP-->>KC: OAuth callback + tokens
|
||||
KC->>KC: Create/update user session
|
||||
KC-->>User: Redirect with auth code
|
||||
User->>App: Auth code
|
||||
App->>KC: Exchange code for tokens
|
||||
KC-->>App: Access token + Refresh token
|
||||
App->>App: Create signed JWT cookie
|
||||
App->>DB: Store/update user record
|
||||
App-->>User: Set keycloak_auth cookie
|
||||
|
||||
Note over User,DB: Subsequent Requests
|
||||
|
||||
User->>App: Request with cookie
|
||||
App->>App: Verify JWT signature
|
||||
App->>KC: Validate token (if needed)
|
||||
KC-->>App: Token valid
|
||||
App->>App: Extract user context
|
||||
App-->>User: Authorized response
|
||||
```
|
||||
|
||||
### Authentication Components
|
||||
|
||||
| Component | Purpose | Location |
|
||||
|-----------|---------|----------|
|
||||
| **Keycloak** | Identity provider, SSO, token management | External service |
|
||||
| **UserAuth** | Abstract auth interface | `openhands/server/user_auth/user_auth.py` |
|
||||
| **SaasUserAuth** | Keycloak implementation | `enterprise/server/auth/saas_user_auth.py` |
|
||||
| **JWT Service** | Token signing/verification | `openhands/app_server/services/jwt_service.py` |
|
||||
| **Auth Routes** | Login/logout endpoints | `enterprise/server/routes/auth.py` |
|
||||
|
||||
### Token Flow
|
||||
|
||||
1. **Keycloak Access Token**: Short-lived token for API access
|
||||
2. **Keycloak Refresh Token**: Long-lived token to obtain new access tokens
|
||||
3. **Signed JWT Cookie**: App Server's session cookie containing encrypted Keycloak tokens
|
||||
4. **Provider Tokens**: OAuth tokens for GitHub, GitLab, etc. (stored separately for git operations)
|
||||
68
openhands/architecture/conversation-startup.md
Normal file
68
openhands/architecture/conversation-startup.md
Normal file
@@ -0,0 +1,68 @@
|
||||
# Conversation Startup & WebSocket Flow
|
||||
|
||||
When a user starts a conversation, this sequence occurs:
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
participant User as User (Browser)
|
||||
participant App as App Server
|
||||
participant SS as Sandbox Service
|
||||
participant RAPI as Runtime API
|
||||
participant Pool as Warm Pool
|
||||
participant Sandbox as Sandbox (Container)
|
||||
participant AS as Agent Server
|
||||
participant AES as Action Execution Server
|
||||
|
||||
Note over User,AES: Phase 1: Conversation Creation
|
||||
User->>App: POST /api/conversations
|
||||
App->>App: Authenticate user
|
||||
App->>SS: Create sandbox
|
||||
|
||||
Note over SS,Pool: Phase 2: Runtime Provisioning
|
||||
SS->>RAPI: POST /start (image, env, config)
|
||||
RAPI->>Pool: Check for warm runtime
|
||||
alt Warm runtime available
|
||||
Pool-->>RAPI: Return warm runtime
|
||||
RAPI->>RAPI: Assign to session
|
||||
else No warm runtime
|
||||
RAPI->>Sandbox: Create new container
|
||||
Sandbox->>AS: Start Agent Server
|
||||
Sandbox->>AES: Start Action Execution Server
|
||||
AES-->>AS: Ready
|
||||
end
|
||||
RAPI-->>SS: Runtime URL + session API key
|
||||
SS-->>App: Sandbox info
|
||||
App-->>User: Conversation ID + Sandbox URL
|
||||
|
||||
Note over User,AES: Phase 3: Direct WebSocket Connection
|
||||
User->>AS: WebSocket: /sockets/events/{id}
|
||||
AS-->>User: Connection accepted
|
||||
AS->>User: Replay historical events
|
||||
|
||||
Note over User,AES: Phase 4: User Sends Message
|
||||
User->>AS: WebSocket: SendMessageRequest
|
||||
AS->>AS: Agent processes message
|
||||
AS->>AS: LLM call → generate action
|
||||
|
||||
Note over User,AES: Phase 5: Action Execution Loop
|
||||
loop Agent Loop
|
||||
AS->>AES: HTTP: Execute action
|
||||
AES->>AES: Run in sandbox
|
||||
AES-->>AS: Observation result
|
||||
AS->>User: WebSocket: Event update
|
||||
AS->>AS: Update state, next action
|
||||
end
|
||||
|
||||
Note over User,AES: Phase 6: Task Complete
|
||||
AS->>User: WebSocket: AgentStateChanged (FINISHED)
|
||||
```
|
||||
|
||||
### Key Points
|
||||
|
||||
1. **Initial Setup via App Server**: The App Server handles authentication and coordinates with the Sandbox Service
|
||||
2. **Runtime API Provisioning**: The Sandbox Service calls the Runtime API, which checks for warm runtimes before creating new containers
|
||||
3. **Warm Pool Optimization**: Pre-warmed runtimes reduce startup latency significantly
|
||||
4. **Direct WebSocket to Sandbox**: Once created, the user's browser connects **directly** to the Agent Server inside the sandbox
|
||||
5. **App Server Not in Hot Path**: After connection, all real-time communication bypasses the App Server entirely
|
||||
6. **Agent Server Orchestrates**: The Agent Server manages the AI loop, calling the Action Execution Server for actual command execution
|
||||
88
openhands/architecture/external-integrations.md
Normal file
88
openhands/architecture/external-integrations.md
Normal file
@@ -0,0 +1,88 @@
|
||||
# External Integrations
|
||||
|
||||
OpenHands integrates with external services (GitHub, Slack, Jira, etc.) through webhook-based event handling:
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
autonumber
|
||||
participant Ext as External Service<br/>(GitHub/Slack/Jira)
|
||||
participant App as App Server
|
||||
participant IntRouter as Integration Router
|
||||
participant Manager as Integration Manager
|
||||
participant Conv as Conversation Service
|
||||
participant Sandbox as Sandbox
|
||||
|
||||
Note over Ext,Sandbox: Webhook Event Flow (e.g., GitHub Issue Created)
|
||||
|
||||
Ext->>App: POST /api/integration/{service}/events
|
||||
App->>IntRouter: Route to service handler
|
||||
IntRouter->>IntRouter: Verify signature<br/>(HMAC/signing secret)
|
||||
|
||||
IntRouter->>Manager: Parse event payload
|
||||
Manager->>Manager: Extract context<br/>(repo, issue, user)
|
||||
Manager->>Manager: Map external user → OpenHands user<br/>(via stored tokens)
|
||||
|
||||
Manager->>Conv: Create conversation<br/>(with issue context)
|
||||
Conv->>Sandbox: Provision sandbox
|
||||
Sandbox-->>Conv: Ready
|
||||
|
||||
Manager->>Sandbox: Start agent with task
|
||||
|
||||
Note over Ext,Sandbox: Agent Works on Task...
|
||||
|
||||
Sandbox-->>Manager: Task complete
|
||||
Manager->>Ext: POST result<br/>(PR, comment, etc.)
|
||||
|
||||
Note over Ext,Sandbox: Callback Flow (Agent → External Service)
|
||||
|
||||
Sandbox->>App: Webhook callback<br/>/api/v1/webhooks
|
||||
App->>Manager: Process callback
|
||||
Manager->>Ext: Update external service
|
||||
```
|
||||
|
||||
### Supported Integrations
|
||||
|
||||
| Integration | Trigger Events | Agent Actions |
|
||||
|-------------|----------------|---------------|
|
||||
| **GitHub** | Issue created, PR opened, @mention | Create PR, comment, push commits |
|
||||
| **GitLab** | Issue created, MR opened | Create MR, comment, push commits |
|
||||
| **Slack** | @mention in channel | Reply in thread, create tasks |
|
||||
| **Jira** | Issue created/updated | Update ticket, add comments |
|
||||
| **Linear** | Issue created | Update status, add comments |
|
||||
|
||||
### Integration Components
|
||||
|
||||
| Component | Purpose | Location |
|
||||
|-----------|---------|----------|
|
||||
| **Integration Routes** | Webhook endpoints per service | `enterprise/server/routes/integration/` |
|
||||
| **Integration Managers** | Business logic per service | `enterprise/integrations/{service}/` |
|
||||
| **Token Manager** | Store/retrieve OAuth tokens | `enterprise/server/auth/token_manager.py` |
|
||||
| **Callback Processor** | Handle agent → service updates | `enterprise/integrations/{service}/*_callback_processor.py` |
|
||||
|
||||
### Integration Authentication
|
||||
|
||||
```
|
||||
External Service (e.g., GitHub)
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────┐
|
||||
│ GitHub App Installation │
|
||||
│ - Webhook secret for signature │
|
||||
│ - App private key for API calls │
|
||||
└─────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────┐
|
||||
│ User Account Linking │
|
||||
│ - Keycloak user ID │
|
||||
│ - GitHub user ID │
|
||||
│ - Stored OAuth tokens │
|
||||
└─────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────┐
|
||||
│ Agent Execution │
|
||||
│ - Uses linked tokens for API │
|
||||
│ - Can push, create PRs, comment │
|
||||
└─────────────────────────────────┘
|
||||
```
|
||||
103
openhands/architecture/observability.md
Normal file
103
openhands/architecture/observability.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# Metrics, Logs & Observability
|
||||
|
||||
OpenHands uses multiple systems for monitoring, analytics, and debugging:
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph Sources["Sources"]
|
||||
Agent["Agent Server"]
|
||||
App["App Server"]
|
||||
Frontend["Frontend"]
|
||||
end
|
||||
|
||||
subgraph Collection["Collection"]
|
||||
JSONLog["JSON Logs"]
|
||||
Metrics["Metrics"]
|
||||
PH["PostHog"]
|
||||
end
|
||||
|
||||
subgraph Services["Services"]
|
||||
DD["DataDog"]
|
||||
PHCloud["PostHog Cloud"]
|
||||
end
|
||||
|
||||
Agent --> JSONLog
|
||||
App --> JSONLog
|
||||
App --> PH
|
||||
Frontend --> PH
|
||||
|
||||
JSONLog --> DD
|
||||
Metrics --> DD
|
||||
PH --> PHCloud
|
||||
```
|
||||
|
||||
### Logging Infrastructure
|
||||
|
||||
| Component | Format | Destination | Purpose |
|
||||
|-----------|--------|-------------|---------|
|
||||
| **Application Logs** | JSON (when `LOG_JSON=1`) | stdout → DataDog | Debugging, error tracking |
|
||||
| **Access Logs** | JSON (Uvicorn) | stdout → DataDog | Request tracing |
|
||||
| **LLM Debug Logs** | Plain text | File (optional) | LLM call debugging |
|
||||
|
||||
### JSON Log Format
|
||||
|
||||
When `LOG_JSON=1` is set, all logs are emitted as single-line JSON for DataDog ingestion:
|
||||
|
||||
```json
|
||||
{
|
||||
"message": "Conversation started",
|
||||
"severity": "INFO",
|
||||
"conversation_id": "abc-123",
|
||||
"user_id": "user-456",
|
||||
"timestamp": "2024-01-15T10:30:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
### Metrics Tracked
|
||||
|
||||
| Metric | Tracked By | Storage | Purpose |
|
||||
|--------|------------|---------|---------|
|
||||
| **LLM Cost** | `Metrics` class | Conversation stats file | Billing, budget limits |
|
||||
| **Token Usage** | `Metrics` class | Conversation stats file | Usage analytics |
|
||||
| **Response Latency** | `Metrics` class | Conversation stats file | Performance monitoring |
|
||||
| **User Events** | PostHog | PostHog Cloud | Product analytics |
|
||||
| **Feature Flags** | PostHog | PostHog Cloud | Gradual rollouts |
|
||||
|
||||
### PostHog Analytics
|
||||
|
||||
PostHog is used for both product analytics and feature flags:
|
||||
|
||||
**Frontend Events:**
|
||||
- `conversation_started`
|
||||
- `download_trajectory_button_clicked`
|
||||
- Feature flag checks
|
||||
|
||||
**Backend Events:**
|
||||
- Experiment assignments
|
||||
- Conversion tracking
|
||||
|
||||
### DataDog Integration
|
||||
|
||||
Logs are ingested by DataDog through structured JSON output:
|
||||
|
||||
1. **Log Collection**: Container stdout/stderr → DataDog Agent → DataDog Logs
|
||||
2. **APM Traces**: Distributed tracing across services (when enabled)
|
||||
3. **Dashboards**: Custom dashboards for:
|
||||
- Error rates by service
|
||||
- Request latency percentiles
|
||||
- Conversation success rates
|
||||
- LLM cost tracking
|
||||
|
||||
### Conversation Stats Persistence
|
||||
|
||||
Per-conversation metrics are persisted for billing and analytics:
|
||||
|
||||
```python
|
||||
# Location: openhands/server/services/conversation_stats.py
|
||||
ConversationStats:
|
||||
- service_to_metrics: Dict[str, Metrics]
|
||||
- accumulated_cost: float
|
||||
- token_usage: TokenUsage
|
||||
|
||||
# Stored at: {file_store}/conversation_stats/{conversation_id}.pkl
|
||||
```
|
||||
64
openhands/architecture/system-architecture.md
Normal file
64
openhands/architecture/system-architecture.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# System Architecture Overview
|
||||
|
||||
OpenHands uses a multi-tier architecture with these main components:
|
||||
|
||||
```mermaid
|
||||
flowchart TB
|
||||
subgraph AppServer["OpenHands App Server (Single Instance)"]
|
||||
API["REST API<br/>(FastAPI)"]
|
||||
Auth["Authentication"]
|
||||
ConvMgr["Conversation<br/>Manager"]
|
||||
SandboxSvc["Sandbox<br/>Service"]
|
||||
end
|
||||
|
||||
subgraph RuntimeAPI["Runtime API (Separate Service)"]
|
||||
RuntimeMgr["Runtime<br/>Manager"]
|
||||
WarmPool["Warm Runtime<br/>Pool"]
|
||||
end
|
||||
|
||||
subgraph Sandbox["Sandbox (Docker/K8s Container)"]
|
||||
AS["Agent Server<br/>(openhands-agent-server)"]
|
||||
AES["Action Execution<br/>Server"]
|
||||
Browser["Browser<br/>Environment"]
|
||||
FS["File System"]
|
||||
end
|
||||
|
||||
User["User"] -->|"1. HTTP/REST"| API
|
||||
API --> Auth
|
||||
Auth --> ConvMgr
|
||||
ConvMgr --> SandboxSvc
|
||||
|
||||
SandboxSvc -->|"2. POST /start"| RuntimeMgr
|
||||
RuntimeMgr -->|"Check pool"| WarmPool
|
||||
WarmPool -->|"Warm runtime<br/>available?"| RuntimeMgr
|
||||
RuntimeMgr -->|"3. Provision or<br/>assign runtime"| Sandbox
|
||||
|
||||
User -.->|"4. WebSocket<br/>(Direct)"| AS
|
||||
|
||||
AS -->|"HTTP"| AES
|
||||
AES --> Browser
|
||||
AES --> FS
|
||||
```
|
||||
|
||||
### Component Responsibilities
|
||||
|
||||
| Component | Location | Instances | Purpose |
|
||||
|-----------|----------|-----------|---------|
|
||||
| **App Server** | Host | 1 per deployment | REST API, auth, conversation management |
|
||||
| **Sandbox Service** | Inside App Server | 1 | Manages sandbox lifecycle, calls Runtime API |
|
||||
| **Runtime API** | Separate service | 1 per deployment | Provisions runtimes, manages warm pool |
|
||||
| **Agent Server** | Inside sandbox | 1 per sandbox | AI agent loop, LLM calls, state management |
|
||||
| **Action Execution Server** | Inside sandbox | 1 per sandbox | Execute bash, file ops, browser actions |
|
||||
|
||||
### Runtime API Endpoints
|
||||
|
||||
The Runtime API manages the actual container/pod lifecycle:
|
||||
|
||||
| Endpoint | Purpose |
|
||||
|----------|---------|
|
||||
| `POST /start` | Start a new runtime (or assign from warm pool) |
|
||||
| `POST /stop` | Stop and clean up a runtime |
|
||||
| `POST /pause` | Pause a running runtime |
|
||||
| `POST /resume` | Resume a paused runtime |
|
||||
| `GET /sessions/{id}` | Get runtime status |
|
||||
| `GET /list` | List all active runtimes |
|
||||
@@ -136,7 +136,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
if self.config.model.startswith('openhands/'):
|
||||
model_name = self.config.model.removeprefix('openhands/')
|
||||
self.config.model = f'litellm_proxy/{model_name}'
|
||||
self.config.base_url = _get_openhands_llm_base_url()
|
||||
self.config.base_url = 'https://llm-proxy.app.all-hands.dev/'
|
||||
logger.debug(
|
||||
f'Rewrote openhands/{model_name} to {self.config.model} with base URL {self.config.base_url}'
|
||||
)
|
||||
@@ -851,18 +851,3 @@ class LLM(RetryMixin, DebugMixin):
|
||||
|
||||
# let pydantic handle the serialization
|
||||
return [message.model_dump() for message in messages]
|
||||
|
||||
|
||||
def _get_openhands_llm_base_url():
|
||||
# Get the API url if specified
|
||||
lite_llm_api_url = os.getenv('LITE_LLM_API_URL')
|
||||
if lite_llm_api_url:
|
||||
return lite_llm_api_url
|
||||
|
||||
# Fallback to using web_host.
|
||||
web_host = os.getenv('WEB_HOST')
|
||||
if web_host and ('.staging.' in web_host or web_host.startswith('staging')):
|
||||
return 'https://llm-proxy.staging.all-hands.dev/'
|
||||
|
||||
# Use the default
|
||||
return 'https://llm-proxy.app.all-hands.dev/'
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
@@ -33,9 +31,7 @@ class Settings(BaseModel):
|
||||
user_version: int | None = None
|
||||
remote_runtime_resource_factor: int | None = None
|
||||
# Planned to be removed from settings
|
||||
secrets_store: Annotated[Secrets, Field(frozen=True)] = Field(
|
||||
default_factory=Secrets
|
||||
)
|
||||
secrets_store: Secrets = Field(default_factory=Secrets, frozen=True)
|
||||
enable_default_condenser: bool = True
|
||||
enable_sound_notifications: bool = False
|
||||
enable_proactive_conversation_starters: bool = True
|
||||
|
||||
20
poetry.lock
generated
20
poetry.lock
generated
@@ -7731,14 +7731,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.9.1"
|
||||
version = "1.9.0"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.9.1-py3-none-any.whl", hash = "sha256:ea1457760505b9ebfe6aabea08dedd010ce93aeb93edb450f00e25a0d056a723"},
|
||||
{file = "openhands_agent_server-1.9.1.tar.gz", hash = "sha256:d92a29a9d5aa94207519a5f8daad7c0a3d6641d5cba9f763f25aa4e85713fa0f"},
|
||||
{file = "openhands_agent_server-1.9.0-py3-none-any.whl", hash = "sha256:44b65fac5bb831541eb2e8726afb2682bde4816b4c6c90be9ad3cafd3dbcf971"},
|
||||
{file = "openhands_agent_server-1.9.0.tar.gz", hash = "sha256:ac41a948acf64ed661a9f383c293c305176f92bd12e6fc6362f5414cb7874ee1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -7755,14 +7755,14 @@ wsproto = ">=1.2.0"
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.9.1"
|
||||
version = "1.9.0"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_sdk-1.9.1-py3-none-any.whl", hash = "sha256:0e732dfe0d91289536ea0410db9554d5a5b0326f60e547ea7a9d8ddab5fe93e4"},
|
||||
{file = "openhands_sdk-1.9.1.tar.gz", hash = "sha256:c6ba33f85efa4c2ec63eb1040cbe82839662bcbcf323654ed071a9ad38ce7994"},
|
||||
{file = "openhands_sdk-1.9.0-py3-none-any.whl", hash = "sha256:b427d8b9e587a5360c7d61742c290601998557e9b38b1c9e11a297659812c00d"},
|
||||
{file = "openhands_sdk-1.9.0.tar.gz", hash = "sha256:70048888fd4fbe44a86c35c402bbb99d30cf0cba50579ee1a8e3f43e05154150"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -7783,14 +7783,14 @@ boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.9.1"
|
||||
version = "1.9.0"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_tools-1.9.1-py3-none-any.whl", hash = "sha256:411819657e00ffac5d5b1ba9adc6eb65a0a17cbefb5e3e1a34bb132ff61c59f2"},
|
||||
{file = "openhands_tools-1.9.1.tar.gz", hash = "sha256:331608994cce22b662038a2fed0bf7d2c1bb8dc27b1fc0a12a646e9bd76e0843"},
|
||||
{file = "openhands_tools-1.9.0-py3-none-any.whl", hash = "sha256:8becde0e913a31babb41eb93a8c10bf41d87ca1febd07bc958839c3583655305"},
|
||||
{file = "openhands_tools-1.9.0.tar.gz", hash = "sha256:d45f5f5210cb2bbcd8ab5f3a32051db1a532d0ec07cd306105f95cde42cf67f2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -17367,4 +17367,4 @@ third-party-runtimes = ["daytona", "e2b-code-interpreter", "modal", "runloop-api
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12,<3.14"
|
||||
content-hash = "fecab94e6c18e6da0c67c3a249f20cd938b47a2faff492994311c36ac4e0019a"
|
||||
content-hash = "af2159c3b8723a036d7c3f3ddd0b45ce149acd20d164c17856be7db48a35c695"
|
||||
|
||||
@@ -54,9 +54,9 @@ dependencies = [
|
||||
"numpy",
|
||||
"openai==2.8",
|
||||
"openhands-aci==0.3.2",
|
||||
"openhands-agent-server==1.9.1",
|
||||
"openhands-sdk==1.9.1",
|
||||
"openhands-tools==1.9.1",
|
||||
"openhands-agent-server==1.9",
|
||||
"openhands-sdk==1.9",
|
||||
"openhands-tools==1.9",
|
||||
"opentelemetry-api>=1.33.1",
|
||||
"opentelemetry-exporter-otlp-proto-grpc>=1.33.1",
|
||||
"pathspec>=0.12.1",
|
||||
@@ -283,9 +283,9 @@ pybase62 = "^1.0.0"
|
||||
#openhands-agent-server = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-agent-server", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
|
||||
#openhands-sdk = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-sdk", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
|
||||
#openhands-tools = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-tools", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
|
||||
openhands-sdk = "1.9.1"
|
||||
openhands-agent-server = "1.9.1"
|
||||
openhands-tools = "1.9.1"
|
||||
openhands-sdk = "1.9.0"
|
||||
openhands-agent-server = "1.9.0"
|
||||
openhands-tools = "1.9.0"
|
||||
python-jose = { version = ">=3.3", extras = [ "cryptography" ] }
|
||||
sqlalchemy = { extras = [ "asyncio" ], version = "^2.0.40" }
|
||||
pg8000 = "^1.31.5"
|
||||
|
||||
@@ -17,7 +17,6 @@ from openhands.app_server.app_conversation.app_conversation_service_base import
|
||||
)
|
||||
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.sdk.context.skills import Skill
|
||||
|
||||
|
||||
class MockUserInfo:
|
||||
@@ -921,251 +920,347 @@ async def test_configure_git_user_settings_special_characters_in_name(mock_works
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for load_and_merge_all_skills (updated to use agent-server)
|
||||
# Tests for load_and_merge_all_skills with org skills
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestMergeSkills:
|
||||
"""Test _merge_skills method."""
|
||||
class TestLoadAndMergeAllSkillsWithOrgSkills:
|
||||
"""Test load_and_merge_all_skills includes organization skills."""
|
||||
|
||||
def test_merges_skills_with_no_duplicates(self):
|
||||
"""Test merging skill lists with no duplicate names."""
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
|
||||
)
|
||||
async def test_load_and_merge_includes_org_skills(
|
||||
self,
|
||||
mock_load_repo,
|
||||
mock_load_org,
|
||||
mock_load_user,
|
||||
mock_load_global,
|
||||
mock_load_sandbox,
|
||||
):
|
||||
"""Test that load_and_merge_all_skills loads and merges org skills."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
|
||||
with patch.object(
|
||||
AppConversationServiceBase,
|
||||
'__abstractmethods__',
|
||||
set(),
|
||||
):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True, user_context=mock_user_context
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=mock_user_context,
|
||||
)
|
||||
|
||||
skill1 = Mock(spec=Skill)
|
||||
skill1.name = 'skill1'
|
||||
skill2 = Mock(spec=Skill)
|
||||
skill2.name = 'skill2'
|
||||
skill3 = Mock(spec=Skill)
|
||||
skill3.name = 'skill3'
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
sandbox.exposed_urls = []
|
||||
remote_workspace = AsyncMock()
|
||||
|
||||
skill_lists = [[skill1], [skill2], [skill3]]
|
||||
# Create distinct mock skills for each source
|
||||
sandbox_skill = Mock()
|
||||
sandbox_skill.name = 'sandbox_skill'
|
||||
global_skill = Mock()
|
||||
global_skill.name = 'global_skill'
|
||||
user_skill = Mock()
|
||||
user_skill.name = 'user_skill'
|
||||
org_skill = Mock()
|
||||
org_skill.name = 'org_skill'
|
||||
repo_skill = Mock()
|
||||
repo_skill.name = 'repo_skill'
|
||||
|
||||
mock_load_sandbox.return_value = [sandbox_skill]
|
||||
mock_load_global.return_value = [global_skill]
|
||||
mock_load_user.return_value = [user_skill]
|
||||
mock_load_org.return_value = [org_skill]
|
||||
mock_load_repo.return_value = [repo_skill]
|
||||
|
||||
# Act
|
||||
result = service._merge_skills(skill_lists)
|
||||
result = await service.load_and_merge_all_skills(
|
||||
sandbox, remote_workspace, 'owner/repo', '/workspace'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 3
|
||||
assert len(result) == 5
|
||||
names = {s.name for s in result}
|
||||
assert names == {'skill1', 'skill2', 'skill3'}
|
||||
|
||||
def test_merges_skills_with_duplicates_later_wins(self):
|
||||
"""Test that later skill lists override earlier ones for duplicate names."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True, user_context=mock_user_context
|
||||
assert names == {
|
||||
'sandbox_skill',
|
||||
'global_skill',
|
||||
'user_skill',
|
||||
'org_skill',
|
||||
'repo_skill',
|
||||
}
|
||||
mock_load_org.assert_called_once_with(
|
||||
remote_workspace, 'owner/repo', '/workspace', mock_user_context
|
||||
)
|
||||
|
||||
skill1_v1 = Mock(spec=Skill)
|
||||
skill1_v1.name = 'skill1'
|
||||
skill1_v1.version = 'v1'
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
|
||||
)
|
||||
async def test_load_and_merge_org_skills_precedence(
|
||||
self,
|
||||
mock_load_repo,
|
||||
mock_load_org,
|
||||
mock_load_user,
|
||||
mock_load_global,
|
||||
mock_load_sandbox,
|
||||
):
|
||||
"""Test that org skills have correct precedence (higher than user, lower than repo)."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(
|
||||
AppConversationServiceBase,
|
||||
'__abstractmethods__',
|
||||
set(),
|
||||
):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=mock_user_context,
|
||||
)
|
||||
|
||||
skill1_v2 = Mock(spec=Skill)
|
||||
skill1_v2.name = 'skill1'
|
||||
skill1_v2.version = 'v2'
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
sandbox.exposed_urls = []
|
||||
remote_workspace = AsyncMock()
|
||||
|
||||
skill2 = Mock(spec=Skill)
|
||||
skill2.name = 'skill2'
|
||||
# Create skills with same name but different sources
|
||||
user_skill = Mock()
|
||||
user_skill.name = 'common_skill'
|
||||
user_skill.source = 'user'
|
||||
|
||||
skill_lists = [[skill1_v1], [skill1_v2, skill2]]
|
||||
org_skill = Mock()
|
||||
org_skill.name = 'common_skill'
|
||||
org_skill.source = 'org'
|
||||
|
||||
repo_skill = Mock()
|
||||
repo_skill.name = 'common_skill'
|
||||
repo_skill.source = 'repo'
|
||||
|
||||
mock_load_sandbox.return_value = []
|
||||
mock_load_global.return_value = []
|
||||
mock_load_user.return_value = [user_skill]
|
||||
mock_load_org.return_value = [org_skill]
|
||||
mock_load_repo.return_value = [repo_skill]
|
||||
|
||||
# Act
|
||||
result = service._merge_skills(skill_lists)
|
||||
result = await service.load_and_merge_all_skills(
|
||||
sandbox, remote_workspace, 'owner/repo', '/workspace'
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should have only one skill with repo source (highest precedence)
|
||||
assert len(result) == 1
|
||||
assert result[0].source == 'repo'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
|
||||
)
|
||||
async def test_load_and_merge_org_skills_override_user_skills(
|
||||
self,
|
||||
mock_load_repo,
|
||||
mock_load_org,
|
||||
mock_load_user,
|
||||
mock_load_global,
|
||||
mock_load_sandbox,
|
||||
):
|
||||
"""Test that org skills override user skills for same name."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(
|
||||
AppConversationServiceBase,
|
||||
'__abstractmethods__',
|
||||
set(),
|
||||
):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=mock_user_context,
|
||||
)
|
||||
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
sandbox.exposed_urls = []
|
||||
remote_workspace = AsyncMock()
|
||||
|
||||
# Create skills with same name
|
||||
user_skill = Mock()
|
||||
user_skill.name = 'shared_skill'
|
||||
user_skill.priority = 'low'
|
||||
|
||||
org_skill = Mock()
|
||||
org_skill.name = 'shared_skill'
|
||||
org_skill.priority = 'high'
|
||||
|
||||
mock_load_sandbox.return_value = []
|
||||
mock_load_global.return_value = []
|
||||
mock_load_user.return_value = [user_skill]
|
||||
mock_load_org.return_value = [org_skill]
|
||||
mock_load_repo.return_value = []
|
||||
|
||||
# Act
|
||||
result = await service.load_and_merge_all_skills(
|
||||
sandbox, remote_workspace, 'owner/repo', '/workspace'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 1
|
||||
assert result[0].priority == 'high' # Org skill should win
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
|
||||
)
|
||||
async def test_load_and_merge_handles_org_skills_failure(
|
||||
self,
|
||||
mock_load_repo,
|
||||
mock_load_org,
|
||||
mock_load_user,
|
||||
mock_load_global,
|
||||
mock_load_sandbox,
|
||||
):
|
||||
"""Test that failure to load org skills doesn't break the overall process."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(
|
||||
AppConversationServiceBase,
|
||||
'__abstractmethods__',
|
||||
set(),
|
||||
):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=mock_user_context,
|
||||
)
|
||||
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
sandbox.exposed_urls = []
|
||||
remote_workspace = AsyncMock()
|
||||
|
||||
global_skill = Mock()
|
||||
global_skill.name = 'global_skill'
|
||||
repo_skill = Mock()
|
||||
repo_skill.name = 'repo_skill'
|
||||
|
||||
mock_load_sandbox.return_value = []
|
||||
mock_load_global.return_value = [global_skill]
|
||||
mock_load_user.return_value = []
|
||||
mock_load_org.return_value = [] # Org skills failed/empty
|
||||
mock_load_repo.return_value = [repo_skill]
|
||||
|
||||
# Act
|
||||
result = await service.load_and_merge_all_skills(
|
||||
sandbox, remote_workspace, 'owner/repo', '/workspace'
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should still have skills from other sources
|
||||
assert len(result) == 2
|
||||
skill1_result = next(s for s in result if s.name == 'skill1')
|
||||
assert skill1_result.version == 'v2'
|
||||
|
||||
|
||||
class TestLoadAndMergeAllSkills:
|
||||
"""Test load_and_merge_all_skills method (updated to use agent-server)."""
|
||||
names = {s.name for s in result}
|
||||
assert names == {'global_skill', 'repo_skill'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.build_org_config'
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.build_sandbox_config'
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
|
||||
)
|
||||
async def test_loads_skills_successfully(
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
|
||||
)
|
||||
async def test_load_and_merge_no_selected_repository(
|
||||
self,
|
||||
mock_build_sandbox_config,
|
||||
mock_build_org_config,
|
||||
mock_load_skills,
|
||||
mock_load_repo,
|
||||
mock_load_org,
|
||||
mock_load_user,
|
||||
mock_load_global,
|
||||
mock_load_sandbox,
|
||||
):
|
||||
"""Test successfully loading skills from agent-server."""
|
||||
"""Test skill loading when no repository is selected."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
|
||||
with patch.object(
|
||||
AppConversationServiceBase,
|
||||
'__abstractmethods__',
|
||||
set(),
|
||||
):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True, user_context=mock_user_context
|
||||
init_git_in_empty_workspace=True,
|
||||
user_context=mock_user_context,
|
||||
)
|
||||
|
||||
mock_workspace = AsyncMock()
|
||||
mock_workspace.working_dir = '/workspace'
|
||||
|
||||
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
|
||||
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
exposed_url = ExposedUrl(
|
||||
name='AGENT_SERVER', url='http://localhost:8000', port=8000
|
||||
)
|
||||
sandbox.exposed_urls = [exposed_url]
|
||||
sandbox.session_api_key = 'test-api-key'
|
||||
sandbox.exposed_urls = []
|
||||
remote_workspace = AsyncMock()
|
||||
|
||||
skill1 = Mock(spec=Skill)
|
||||
skill1.name = 'skill1'
|
||||
skill2 = Mock(spec=Skill)
|
||||
skill2.name = 'skill2'
|
||||
global_skill = Mock()
|
||||
global_skill.name = 'global_skill'
|
||||
|
||||
mock_load_skills.return_value = [skill1, skill2]
|
||||
mock_build_org_config.return_value = {'repository': 'owner/repo'}
|
||||
mock_build_sandbox_config.return_value = {'exposed_urls': []}
|
||||
mock_load_sandbox.return_value = []
|
||||
mock_load_global.return_value = [global_skill]
|
||||
mock_load_user.return_value = []
|
||||
mock_load_org.return_value = []
|
||||
mock_load_repo.return_value = []
|
||||
|
||||
# Act
|
||||
result = await service.load_and_merge_all_skills(
|
||||
sandbox, 'owner/repo', '/workspace', 'http://localhost:8000'
|
||||
sandbox, remote_workspace, None, '/workspace'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0].name == 'skill1'
|
||||
assert result[1].name == 'skill2'
|
||||
mock_load_skills.assert_called_once()
|
||||
call_kwargs = mock_load_skills.call_args[1]
|
||||
assert call_kwargs['agent_server_url'] == 'http://localhost:8000'
|
||||
assert call_kwargs['session_api_key'] == 'test-api-key'
|
||||
assert call_kwargs['project_dir'] == '/workspace/repo'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
|
||||
)
|
||||
async def test_returns_empty_list_when_no_agent_server_url(self, mock_load_skills):
|
||||
"""Test returns empty list when agent-server URL is not available."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True, user_context=mock_user_context
|
||||
assert len(result) == 1
|
||||
# Org skills should be called even with None repository
|
||||
mock_load_org.assert_called_once_with(
|
||||
remote_workspace, None, '/workspace', mock_user_context
|
||||
)
|
||||
|
||||
AsyncMock()
|
||||
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
|
||||
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
exposed_url = ExposedUrl(
|
||||
name='VSCODE', url='http://localhost:8080', port=8080
|
||||
)
|
||||
sandbox.exposed_urls = [exposed_url]
|
||||
|
||||
# Act - pass empty string to simulate no agent server URL
|
||||
# This should still call load_skills_from_agent_server but it will fail
|
||||
result = await service.load_and_merge_all_skills(
|
||||
sandbox, 'owner/repo', '/workspace', ''
|
||||
)
|
||||
|
||||
# Assert - should return empty list when agent_server_url is empty
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.build_org_config'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.build_sandbox_config'
|
||||
)
|
||||
async def test_uses_working_dir_when_no_repository(
|
||||
self,
|
||||
mock_build_sandbox_config,
|
||||
mock_build_org_config,
|
||||
mock_load_skills,
|
||||
):
|
||||
"""Test uses working_dir as project_dir when no repository is selected."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True, user_context=mock_user_context
|
||||
)
|
||||
|
||||
AsyncMock()
|
||||
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
|
||||
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
exposed_url = ExposedUrl(
|
||||
name='AGENT_SERVER', url='http://localhost:8000', port=8000
|
||||
)
|
||||
sandbox.exposed_urls = [exposed_url]
|
||||
sandbox.session_api_key = 'test-key'
|
||||
|
||||
mock_load_skills.return_value = []
|
||||
mock_build_org_config.return_value = None
|
||||
mock_build_sandbox_config.return_value = None
|
||||
|
||||
# Act
|
||||
await service.load_and_merge_all_skills(
|
||||
sandbox, None, '/workspace', 'http://localhost:8000'
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_load_skills.call_args[1]
|
||||
assert call_kwargs['project_dir'] == '/workspace'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.build_org_config'
|
||||
)
|
||||
@patch(
|
||||
'openhands.app_server.app_conversation.app_conversation_service_base.build_sandbox_config'
|
||||
)
|
||||
async def test_handles_exception_gracefully(
|
||||
self,
|
||||
mock_build_sandbox_config,
|
||||
mock_build_org_config,
|
||||
mock_load_skills,
|
||||
):
|
||||
"""Test handles exceptions during skill loading."""
|
||||
# Arrange
|
||||
mock_user_context = Mock(spec=UserContext)
|
||||
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
|
||||
service = AppConversationServiceBase(
|
||||
init_git_in_empty_workspace=True, user_context=mock_user_context
|
||||
)
|
||||
|
||||
AsyncMock()
|
||||
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
|
||||
|
||||
sandbox = Mock(spec=SandboxInfo)
|
||||
exposed_url = ExposedUrl(
|
||||
name='AGENT_SERVER', url='http://localhost:8000', port=8000
|
||||
)
|
||||
sandbox.exposed_urls = [exposed_url]
|
||||
sandbox.session_api_key = 'test-key'
|
||||
|
||||
mock_load_skills.side_effect = Exception('Network error')
|
||||
|
||||
# Act
|
||||
result = await service.load_and_merge_all_skills(
|
||||
sandbox, 'owner/repo', '/workspace', 'http://localhost:8000'
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
@@ -1165,50 +1165,6 @@ class TestLiveStatusAppConversationService:
|
||||
)
|
||||
self.mock_event_service.search_events.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_conversation_calls_search_events_with_correct_parameter_name(
|
||||
self,
|
||||
):
|
||||
"""Test that export_conversation calls search_events with 'conversation_id' parameter, not 'conversation_id__eq'.
|
||||
|
||||
This test verifies the fix for a bug where page_iterator was called with
|
||||
conversation_id__eq instead of conversation_id, causing a TypeError since
|
||||
the search_events method expects conversation_id as its parameter name.
|
||||
"""
|
||||
# Arrange
|
||||
conversation_id = uuid4()
|
||||
|
||||
# Mock conversation info
|
||||
mock_conversation_info = Mock(spec=AppConversationInfo)
|
||||
mock_conversation_info.id = conversation_id
|
||||
mock_conversation_info.model_dump_json = Mock(return_value='{}')
|
||||
|
||||
self.mock_app_conversation_info_service.get_app_conversation_info = AsyncMock(
|
||||
return_value=mock_conversation_info
|
||||
)
|
||||
|
||||
# Mock empty event page to simplify test
|
||||
mock_event_page = Mock()
|
||||
mock_event_page.items = []
|
||||
mock_event_page.next_page_id = None
|
||||
|
||||
self.mock_event_service.search_events = AsyncMock(return_value=mock_event_page)
|
||||
|
||||
# Act
|
||||
await self.service.export_conversation(conversation_id)
|
||||
|
||||
# Assert - Verify search_events was called with 'conversation_id', not 'conversation_id__eq'
|
||||
self.mock_event_service.search_events.assert_called()
|
||||
call_kwargs = self.mock_event_service.search_events.call_args[1]
|
||||
|
||||
assert 'conversation_id' in call_kwargs, (
|
||||
"search_events should be called with 'conversation_id' parameter"
|
||||
)
|
||||
assert 'conversation_id__eq' not in call_kwargs, (
|
||||
"search_events should NOT be called with 'conversation_id__eq' parameter"
|
||||
)
|
||||
assert call_kwargs['conversation_id'] == conversation_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_conversation_large_pagination(self):
|
||||
"""Test download with multiple pages of events."""
|
||||
@@ -1332,7 +1288,7 @@ class TestLiveStatusAppConversationService:
|
||||
task.sandbox_id = self.mock_sandbox.id
|
||||
yield task
|
||||
|
||||
async def mock_run_setup_scripts(task, sandbox, workspace, agent_server_url):
|
||||
async def mock_run_setup_scripts(task, sandbox, workspace):
|
||||
yield task
|
||||
|
||||
self.service._wait_for_sandbox_start = mock_wait_for_sandbox
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user