Compare commits

..

34 Commits

Author SHA1 Message Date
Chuck Butkus 4557d98cf9 Fix migration for byor key 2026-01-27 23:02:05 -05:00
Chuck Butkus f97117a1fe Fix Lint 2026-01-27 22:41:48 -05:00
openhands 89228a01c3 Add encrypt_legacy_model() and encrypt UserSettings fields in downgrade_user()
- Created encrypt_legacy_model(), encrypt_legacy_kwargs(), and encrypt_legacy_value()
  functions in encrypt_utils.py as the inverse of decrypt_legacy_* functions
- Updated downgrade_user() in UserStore to encrypt sensitive fields
  (llm_api_key, llm_api_key_for_byor, search_api_key, sandbox_api_key)
  before merging UserSettings into the database

Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-28 03:40:00 +00:00
Chuck Butkus b7b569993e Fix UserSettings creation from Org tables 2026-01-27 21:32:22 -05:00
Tim O'Farrell 102095affb Add downgrade script and methods for reverting user migration (#12629)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: chuckbutkus <chuck@all-hands.dev>
2026-01-27 14:41:34 -07:00
John-Mason P. Shackelford b6ce45b474 feat(app_server): start conversations with remote plugins via REST API (#12338)
- Add `PluginSpec` model with plugin configuration parameters extending SDK's `PluginSource`
- Extend app-conversations API to accept plugins specification in `AppConversationStartRequest`
- Propagate plugin source, ref, and repo_path to agent server's `StartConversationRequest`
- Include plugin parameters in initial conversation message for agent context

Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-27 16:26:38 -05:00
Tim O'Farrell 11c87caba4 fix(backend): fix callback state not persisting due to dual-session conflict (#12627)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-27 10:02:48 -07:00
Abhay Mishra b8a608c45e Fix: branch/repo dropdown reset on click (#12501)
Co-authored-by: hieptl <hieptl.developer@gmail.com>
2026-01-27 16:08:08 +07:00
Tim O'Farrell 8a446787be Migrate SDK to 1.10.0 (#12614) 2026-01-27 04:26:06 +00:00
Tim O'Farrell 353124e171 Bump SDK to 1.10.0 (#12613) 2026-01-27 03:50:30 +00:00
chuckbutkus e9298c89bd Fix org migration (#12612)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-26 20:55:55 -05:00
Hiep Le 29b77be807 fix: revert get_user_by_id_async to use sync session_maker (#12610) 2026-01-27 04:39:07 +07:00
mamoodi 7094835ef0 Fix Pydantic UnsupportedFieldAttributeWarning in Settings model (#12600)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-26 12:30:43 -05:00
mamoodi 7ad0eec325 Change error to warning if TOS not accepted (#12605) 2026-01-26 12:30:00 -05:00
Hiep Le 31d5081163 feat(frontend): display plan preview content (#12504) 2026-01-26 23:19:57 +07:00
Abhay Mishra 250736cb7a fix(frontend): display ThinkAction thought content in V1 UI (#12597)
Co-authored-by: hieptl <hieptl.developer@gmail.com>
2026-01-26 14:03:16 +07:00
Tim O'Farrell a9bd3a70c9 Fix V0 Integrations (#12584) 2026-01-24 16:53:16 +00:00
Hiep Le d7436a4af4 fix(backend): asyncsession query error in userstore.get_user_by_id_async (#12586)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-24 23:47:04 +07:00
Tim O'Farrell f327e76be7 Added explicit expired error (#12580) 2026-01-23 12:49:10 -07:00
Hiep Le 52e39e5d12 fix(backend): unable to export conversation (#12577)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-24 01:06:02 +07:00
Graham Neubig 6c5ef256fd fix: pass userId to EmailVerificationModal in login page (#12573)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: hieptl <hieptl.developer@gmail.com>
2026-01-23 23:24:25 +07:00
Tim O'Farrell 373ade8554 Fix org billing (#12562) 2026-01-23 08:42:50 -07:00
Hiep Le 9d0a19cf8f fix(backend): ensure conversation events are written back to google cloud (#12571) 2026-01-23 22:13:08 +07:00
Rohit Malhotra d60dd38d78 fix: preserve query params in returnTo during login redirect (#12567)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-22 21:31:42 -08:00
Hiep Le d5ee799670 feat(backend): develop patch /api/organizations/{orgid} api (#12470)
Co-authored-by: rohitvinodmalhotra@gmail.com <rohitvinodmalhotra@gmail.com>
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Chuck Butkus <chuck@all-hands.dev>
2026-01-23 01:29:35 +07:00
Hiep Le b685fd43dd fix(backend): github proxy state compression (#12387)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Ray Myers <ray.myers@gmail.com>
2026-01-23 00:39:03 +07:00
Hiep Le 0e04f6fdbe feat(backend): develop delete /api/organizations/{orgid} api (#12471)
Co-authored-by: rohitvinodmalhotra@gmail.com <rohitvinodmalhotra@gmail.com>
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Chuck Butkus <chuck@all-hands.dev>
2026-01-23 00:28:55 +07:00
Hiep Le 9c40929197 feat(backend): develop get /api/organizations/{orgid} api (#12274)
Co-authored-by: rohitvinodmalhotra@gmail.com <rohitvinodmalhotra@gmail.com>
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Chuck Butkus <chuck@all-hands.dev>
Co-authored-by: Tim O'Farrell <tofarr@gmail.com>
2026-01-22 23:55:29 +07:00
Hiep Le af309e8586 feat(backend): develop get /api/organizations api (#12373)
Co-authored-by: rohitvinodmalhotra@gmail.com <rohitvinodmalhotra@gmail.com>
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Chuck Butkus <chuck@all-hands.dev>
2026-01-22 23:32:42 +07:00
sp.wack cc5d5c2335 test(frontend): add missing HTTP mocks for conversation history preloading (#12549) 2026-01-22 20:30:27 +04:00
Hiep Le 60e668f4a7 fix(backend): application settings are not updating as expected (#12547) 2026-01-22 21:54:52 +07:00
Hiep Le 743f6256a6 feat(backend): load skills from agent server (#12434) 2026-01-22 20:20:50 +07:00
Mohammed Abdulai a87b4efd41 feat: preload conversation history before websocket connection (#12488)
Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com>
2026-01-22 12:41:27 +00:00
Tim O'Farrell 730d9970f5 Migrate SDK to 1.9.1 (#12540) 2026-01-21 16:14:27 -07:00
82 changed files with 8870 additions and 2818 deletions
+1 -1
View File
@@ -8,7 +8,7 @@ services:
container_name: openhands-app-${DATE:-}
environment:
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/agent-server}
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-0fdea73-python}
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-31536c8-python}
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
ports:
+205
View File
@@ -0,0 +1,205 @@
#!/usr/bin/env python
"""
Downgrade script for migrated users.
This script identifies users who have been migrated (already_migrated=True)
and reverts them back to the pre-migration state.
Usage:
# Dry run - just list the users that would be downgraded
python downgrade_migrated_users.py --dry-run
# Downgrade a specific user by their keycloak_user_id
python downgrade_migrated_users.py --user-id <user_id>
# Downgrade all migrated users (with confirmation)
python downgrade_migrated_users.py --all
# Downgrade all migrated users without confirmation (dangerous!)
python downgrade_migrated_users.py --all --no-confirm
"""
import argparse
import asyncio
import sys
# Add the enterprise directory to the path
sys.path.insert(0, '/workspace/project/OpenHands/enterprise')
from server.logger import logger
from sqlalchemy import select, text
from storage.database import session_maker
from storage.user_settings import UserSettings
from storage.user_store import UserStore
def get_migrated_users() -> list[str]:
"""Get list of keycloak_user_ids for users who have been migrated.
This includes:
1. Users with already_migrated=True in user_settings (migrated users)
2. Users in the 'user' table who don't have a user_settings entry (new sign-ups)
"""
with session_maker() as session:
# Get users from user_settings with already_migrated=True
migrated_result = session.execute(
select(UserSettings.keycloak_user_id).where(
UserSettings.already_migrated.is_(True)
)
)
migrated_users = {row[0] for row in migrated_result.fetchall() if row[0]}
# Get users from the 'user' table (new sign-ups won't have user_settings)
# These are users who signed up after the migration was deployed
new_signup_result = session.execute(
text("""
SELECT CAST(u.id AS VARCHAR)
FROM "user" u
WHERE NOT EXISTS (
SELECT 1 FROM user_settings us
WHERE us.keycloak_user_id = CAST(u.id AS VARCHAR)
)
""")
)
new_signups = {row[0] for row in new_signup_result.fetchall() if row[0]}
# Combine both sets
all_users = migrated_users | new_signups
return list(all_users)
async def downgrade_user(user_id: str) -> bool:
"""Downgrade a single user.
Args:
user_id: The keycloak_user_id to downgrade
Returns:
True if successful, False otherwise
"""
try:
result = await UserStore.downgrade_user(user_id)
if result:
print(f'✓ Successfully downgraded user: {user_id}')
return True
else:
print(f'✗ Failed to downgrade user: {user_id}')
return False
except Exception as e:
print(f'✗ Error downgrading user {user_id}: {e}')
logger.exception(
'downgrade_script:error',
extra={'user_id': user_id, 'error': str(e)},
)
return False
async def main():
parser = argparse.ArgumentParser(
description='Downgrade migrated users back to pre-migration state'
)
parser.add_argument(
'--dry-run',
action='store_true',
help='Just list users that would be downgraded, without making changes',
)
parser.add_argument(
'--user-id',
type=str,
help='Downgrade a specific user by keycloak_user_id',
)
parser.add_argument(
'--all',
action='store_true',
help='Downgrade all migrated users',
)
parser.add_argument(
'--no-confirm',
action='store_true',
help='Skip confirmation prompt (use with caution!)',
)
args = parser.parse_args()
# Get list of migrated users
migrated_users = get_migrated_users()
print(f'\nFound {len(migrated_users)} migrated user(s).')
if args.dry_run:
print('\n--- DRY RUN MODE ---')
print('The following users would be downgraded:')
for user_id in migrated_users:
print(f' - {user_id}')
print('\nNo changes were made.')
return
if args.user_id:
# Downgrade a specific user
if args.user_id not in migrated_users:
print(f'\nUser {args.user_id} is not in the migrated users list.')
print('Either the user was not migrated, or the user_id is incorrect.')
return
print(f'\nDowngrading user: {args.user_id}')
if not args.no_confirm:
confirm = input('Are you sure? (yes/no): ')
if confirm.lower() != 'yes':
print('Cancelled.')
return
success = await downgrade_user(args.user_id)
if success:
print('\nDowngrade completed successfully.')
else:
print('\nDowngrade failed. Check logs for details.')
sys.exit(1)
elif args.all:
# Downgrade all migrated users
if not migrated_users:
print('\nNo migrated users to downgrade.')
return
print(f'\n⚠️ About to downgrade {len(migrated_users)} user(s).')
if not args.no_confirm:
print('\nThis will:')
print(' - Revert LiteLLM team/user budget settings')
print(' - Delete organization entries')
print(' - Delete user entries in the new schema')
print(' - Reset the already_migrated flag')
print('\nUsers to downgrade:')
for user_id in migrated_users[:10]: # Show first 10
print(f' - {user_id}')
if len(migrated_users) > 10:
print(f' ... and {len(migrated_users) - 10} more')
confirm = input('\nType "yes" to proceed: ')
if confirm.lower() != 'yes':
print('Cancelled.')
return
print('\nStarting downgrade...\n')
success_count = 0
fail_count = 0
for user_id in migrated_users:
success = await downgrade_user(user_id)
if success:
success_count += 1
else:
fail_count += 1
print('\n--- Summary ---')
print(f'Successful: {success_count}')
print(f'Failed: {fail_count}')
if fail_count > 0:
sys.exit(1)
else:
parser.print_help()
print('\nPlease specify --dry-run, --user-id, or --all')
if __name__ == '__main__':
asyncio.run(main())
@@ -26,12 +26,14 @@ from integrations.utils import (
from integrations.v1_utils import get_saas_user_auth
from jinja2 import Environment, FileSystemLoader
from pydantic import SecretStr
from server.auth.auth_error import ExpiredError
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
from server.auth.token_manager import TokenManager
from server.utils.conversation_callback_utils import register_callback_processor
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.integrations.service_types import AuthenticationError
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
@@ -347,7 +349,7 @@ class GithubManager(Manager):
msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
except SessionExpiredError as e:
except (AuthenticationError, ExpiredError, SessionExpiredError) as e:
logger.warning(
f'[GitHub] Session expired for user {user_info.username}: {str(e)}'
)
@@ -0,0 +1,28 @@
"""Add git_user_name and git_user_email columns to user table.
Revision ID: 090
Revises: 089
Create Date: 2025-01-22
"""
import sqlalchemy as sa
from alembic import op
revision = '090'
down_revision = '089'
def upgrade() -> None:
op.add_column(
'user',
sa.Column('git_user_name', sa.String, nullable=True),
)
op.add_column(
'user',
sa.Column('git_user_email', sa.String, nullable=True),
)
def downgrade() -> None:
op.drop_column('user', 'git_user_email')
op.drop_column('user', 'git_user_name')
+12 -12
View File
@@ -6102,14 +6102,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
[[package]]
name = "openhands-agent-server"
version = "1.9.0"
version = "1.10.0"
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_agent_server-1.9.0-py3-none-any.whl", hash = "sha256:44b65fac5bb831541eb2e8726afb2682bde4816b4c6c90be9ad3cafd3dbcf971"},
{file = "openhands_agent_server-1.9.0.tar.gz", hash = "sha256:ac41a948acf64ed661a9f383c293c305176f92bd12e6fc6362f5414cb7874ee1"},
{file = "openhands_agent_server-1.10.0-py3-none-any.whl", hash = "sha256:2e21076fff5e7cf9d03a3b011e2c90a6a3a46d2da3f18db9f7553ac413229c22"},
{file = "openhands_agent_server-1.10.0.tar.gz", hash = "sha256:2062da2496a98a6c23201d086f124e02329d6c6d9d1b47be55921c084a29f55a"},
]
[package.dependencies]
@@ -6168,9 +6168,9 @@ memory-profiler = ">=0.61"
numpy = "*"
openai = "2.8"
openhands-aci = "0.3.2"
openhands-agent-server = "1.9"
openhands-sdk = "1.9"
openhands-tools = "1.9"
openhands-agent-server = "1.10"
openhands-sdk = "1.10"
openhands-tools = "1.10"
opentelemetry-api = ">=1.33.1"
opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
pathspec = ">=0.12.1"
@@ -6225,14 +6225,14 @@ url = ".."
[[package]]
name = "openhands-sdk"
version = "1.9.0"
version = "1.10.0"
description = "OpenHands SDK - Core functionality for building AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_sdk-1.9.0-py3-none-any.whl", hash = "sha256:b427d8b9e587a5360c7d61742c290601998557e9b38b1c9e11a297659812c00d"},
{file = "openhands_sdk-1.9.0.tar.gz", hash = "sha256:70048888fd4fbe44a86c35c402bbb99d30cf0cba50579ee1a8e3f43e05154150"},
{file = "openhands_sdk-1.10.0-py3-none-any.whl", hash = "sha256:5c8875f2a07d7fabe3449914639572bef9003821207cb06aa237a239e964eed5"},
{file = "openhands_sdk-1.10.0.tar.gz", hash = "sha256:93371b1af4532266ad2d225b9d7d3d711c745df31888efe643970673f62bdef9"},
]
[package.dependencies]
@@ -6253,14 +6253,14 @@ boto3 = ["boto3 (>=1.35.0)"]
[[package]]
name = "openhands-tools"
version = "1.9.0"
version = "1.10.0"
description = "OpenHands Tools - Runtime tools for AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_tools-1.9.0-py3-none-any.whl", hash = "sha256:8becde0e913a31babb41eb93a8c10bf41d87ca1febd07bc958839c3583655305"},
{file = "openhands_tools-1.9.0.tar.gz", hash = "sha256:d45f5f5210cb2bbcd8ab5f3a32051db1a532d0ec07cd306105f95cde42cf67f2"},
{file = "openhands_tools-1.10.0-py3-none-any.whl", hash = "sha256:1d5d2d1e34cc4ceb02c0ff1f008b06883ad48a8e7236ab8dd61ece64fbf8e2ed"},
{file = "openhands_tools-1.10.0.tar.gz", hash = "sha256:7ed38cb13545ec2c4a35c26ece725d5b35788d30597db8b1904619c043ec1194"},
]
[package.dependencies]
+3
View File
@@ -16,6 +16,7 @@ from keycloak.exceptions import (
KeycloakError,
KeycloakPostError,
)
from server.auth.auth_error import ExpiredError
from server.auth.constants import (
BITBUCKET_APP_CLIENT_ID,
BITBUCKET_APP_CLIENT_SECRET,
@@ -426,6 +427,8 @@ class TokenManager:
access_token = data.get('access_token')
refresh_token = data.get('refresh_token')
if not access_token or not refresh_token:
if data.get('error') == 'bad_refresh_token':
raise ExpiredError()
raise ValueError(
'Failed to refresh token: missing access_token or refresh_token in response.'
)
@@ -14,7 +14,6 @@ from storage.conversation_callback import (
ConversationCallback,
ConversationCallbackProcessor,
)
from storage.database import session_maker
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
@@ -108,13 +107,10 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
f'[GitHub] Sent summary instruction to conversation {conversation_id} {summary_event}'
)
# Update the processor state
# Update the processor state - the outer session will commit this
self.send_summary_instruction = False
callback.set_processor(self)
callback.updated_at = datetime.now()
with session_maker() as session:
session.merge(callback)
session.commit()
return
# Extract the summary from the event store
@@ -130,14 +126,15 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
logger.info(f'[GitHub] Summary sent for conversation {conversation_id}')
# Mark callback as completed status
# Mark callback as completed status - the outer session will commit this
callback.status = CallbackStatus.COMPLETED
callback.updated_at = datetime.now()
with session_maker() as session:
session.merge(callback)
session.commit()
except Exception as e:
logger.exception(
f'[GitHub] Error processing conversation callback: {str(e)}'
)
# Mark callback as error to prevent infinite re-invocation
# The outer session will commit this
callback.status = CallbackStatus.ERROR
callback.updated_at = datetime.now()
+1 -1
View File
@@ -144,7 +144,7 @@ class SetAuthCookieMiddleware:
# "if accepted_tos is not None" as there should not be any users with
# accepted_tos equal to "None"
if accepted_tos is False and request.url.path != '/api/accept_tos':
logger.error('User has not accepted the terms of service')
logger.warning('User has not accepted the terms of service')
raise TosNotAcceptedError
def _should_attach(self, request: Request) -> bool:
+29 -33
View File
@@ -13,46 +13,33 @@ from server.constants import (
STRIPE_API_KEY,
)
from server.logger import logger
from starlette.datastructures import URL
from storage.billing_session import BillingSession
from storage.database import session_maker
from storage.lite_llm_manager import LiteLlmManager
from storage.subscription_access import SubscriptionAccess
from storage.user_store import UserStore
from openhands.app_server.config import get_global_config
from openhands.server.user_auth import get_user_id
stripe.api_key = STRIPE_API_KEY
billing_router = APIRouter(prefix='/api/billing')
# TODO: Add a new app_mode named "ON_PREM" to support self-hosted customers instead of doing this
# and members should comment out the "validate_saas_environment" function if they are developing and testing locally.
def is_all_hands_saas_environment(request: Request) -> bool:
"""Check if the current domain is an All Hands SaaS environment.
Args:
request: FastAPI Request object
Returns:
True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
async def validate_billing_enabled() -> None:
"""
hostname = request.url.hostname or ''
return hostname.endswith('all-hands.dev') or hostname.endswith('openhands.dev')
def validate_saas_environment(request: Request) -> None:
"""Validate that the request is coming from an All Hands SaaS environment.
Args:
request: FastAPI Request object
Raises:
HTTPException: If the request is not from an All Hands SaaS environment
Validate that the billing feature flag is enabled
"""
if not is_all_hands_saas_environment(request):
config = get_global_config()
web_client_config = await config.web_client.get_web_client_config()
if not web_client_config.feature_flags.enable_billing:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='Checkout sessions are only available for All Hands SaaS environments',
detail=(
'Billing is disabled in this environment. '
'Please set OH_WEB_CLIENT_FEATURE_FLAGS_ENABLE_BILLING to enable billing.'
),
)
@@ -154,14 +141,15 @@ async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
async def create_customer_setup_session(
request: Request, user_id: str = Depends(get_user_id)
) -> CreateBillingSessionResponse:
validate_saas_environment(request)
await validate_billing_enabled()
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
base_url = _get_base_url(request)
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_info['customer_id'],
mode='setup',
payment_method_types=['card'],
success_url=f'{request.base_url}?free_credits=success',
cancel_url=f'{request.base_url}',
success_url=f'{base_url}?free_credits=success',
cancel_url=f'{base_url}',
)
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
@@ -173,8 +161,8 @@ async def create_checkout_session(
request: Request,
user_id: str = Depends(get_user_id),
) -> CreateBillingSessionResponse:
validate_saas_environment(request)
await validate_billing_enabled()
base_url = _get_base_url(request)
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_info['customer_id'],
@@ -197,8 +185,8 @@ async def create_checkout_session(
saved_payment_method_options={
'payment_method_save': 'enabled',
},
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
success_url=f'{base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
cancel_url=f'{base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
)
logger.info(
'created_stripe_checkout_session',
@@ -289,7 +277,7 @@ async def success_callback(session_id: str, request: Request):
session.commit()
return RedirectResponse(
f'{request.base_url}settings/billing?checkout=success', status_code=302
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
)
@@ -317,5 +305,13 @@ async def cancel_callback(session_id: str, request: Request):
session.commit()
return RedirectResponse(
f'{request.base_url}settings/billing?checkout=cancel', status_code=302
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302
)
def _get_base_url(request: Request) -> URL:
# Never send any part of the credit card process over a non secure connection
base_url = request.base_url
if base_url.hostname != 'localhost':
base_url = base_url.replace(scheme='https')
return base_url
+9 -2
View File
@@ -1,6 +1,7 @@
import hashlib
import json
import os
import zlib
from base64 import b64decode, b64encode
from urllib.parse import parse_qs, urlencode, urlparse
@@ -51,7 +52,11 @@ def add_github_proxy_routes(app: FastAPI):
state_payload = json.dumps(
[query_params['state'][0], query_params['redirect_uri'][0]]
)
state = b64encode(_fernet().encrypt(state_payload.encode())).decode()
# Compress before encrypting to reduce URL length
# This is critical for feature deployments where reCAPTCHA tokens in state
# can cause "URL too long" errors from GitHub
compressed_payload = zlib.compress(state_payload.encode())
state = b64encode(_fernet().encrypt(compressed_payload)).decode()
query_params['state'] = [state]
query_params['redirect_uri'] = [
f'https://{request.url.netloc}/github-proxy/callback'
@@ -67,7 +72,9 @@ def add_github_proxy_routes(app: FastAPI):
parsed_url = urlparse(str(request.url))
query_params = parse_qs(parsed_url.query)
state = query_params['state'][0]
decrypted_state = _fernet().decrypt(b64decode(state.encode())).decode()
# Decrypt and decompress (reverse of github_proxy_start)
decrypted_payload = _fernet().decrypt(b64decode(state.encode()))
decrypted_state = zlib.decompress(decrypted_payload).decode()
# Build query Params
state, redirect_uri = json.loads(decrypted_state)
+104
View File
@@ -1,4 +1,5 @@
from pydantic import BaseModel, EmailStr, Field
from storage.org import Org
class OrgCreationError(Exception):
@@ -27,6 +28,27 @@ class OrgDatabaseError(OrgCreationError):
pass
class OrgDeletionError(Exception):
"""Base exception for organization deletion errors."""
pass
class OrgAuthorizationError(OrgDeletionError):
"""Raised when user is not authorized to delete organization."""
def __init__(self, message: str = 'Not authorized to delete organization'):
super().__init__(message)
class OrgNotFoundError(Exception):
"""Raised when organization is not found or user doesn't have access."""
def __init__(self, org_id: str):
self.org_id = org_id
super().__init__(f'Organization with id "{org_id}" not found')
class OrgCreate(BaseModel):
"""Request model for creating a new organization."""
@@ -65,3 +87,85 @@ class OrgResponse(BaseModel):
enable_solvability_analysis: bool | None = None
v1_enabled: bool | None = None
credits: float | None = None
@classmethod
def from_org(cls, org: Org, credits: float | None = None) -> 'OrgResponse':
"""Create an OrgResponse from an Org entity.
Args:
org: The organization entity to convert
credits: Optional credits value (defaults to None)
Returns:
OrgResponse: The response model instance
"""
return cls(
id=str(org.id),
name=org.name,
contact_name=org.contact_name,
contact_email=org.contact_email,
conversation_expiration=org.conversation_expiration,
agent=org.agent,
default_max_iterations=org.default_max_iterations,
security_analyzer=org.security_analyzer,
confirmation_mode=org.confirmation_mode,
default_llm_model=org.default_llm_model,
default_llm_api_key_for_byor=None,
default_llm_base_url=org.default_llm_base_url,
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
enable_default_condenser=org.enable_default_condenser
if org.enable_default_condenser is not None
else True,
billing_margin=org.billing_margin,
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters
if org.enable_proactive_conversation_starters is not None
else True,
sandbox_base_container_image=org.sandbox_base_container_image,
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
org_version=org.org_version if org.org_version is not None else 0,
mcp_config=org.mcp_config,
search_api_key=None,
sandbox_api_key=None,
max_budget_per_task=org.max_budget_per_task,
enable_solvability_analysis=org.enable_solvability_analysis,
v1_enabled=org.v1_enabled,
credits=credits,
)
class OrgPage(BaseModel):
"""Paginated response model for organization list."""
items: list[OrgResponse]
next_page_id: str | None = None
class OrgUpdate(BaseModel):
"""Request model for updating an organization."""
# Basic organization information (any authenticated user can update)
contact_name: str | None = None
contact_email: EmailStr | None = Field(default=None, strip_whitespace=True)
conversation_expiration: int | None = None
default_max_iterations: int | None = Field(default=None, gt=0)
remote_runtime_resource_factor: int | None = Field(default=None, gt=0)
billing_margin: float | None = Field(default=None, ge=0, le=1)
enable_proactive_conversation_starters: bool | None = None
sandbox_base_container_image: str | None = None
sandbox_runtime_container_image: str | None = None
mcp_config: dict | None = None
sandbox_api_key: str | None = None
max_budget_per_task: float | None = Field(default=None, gt=0)
enable_solvability_analysis: bool | None = None
v1_enabled: bool | None = None
# LLM settings (require admin/owner role)
default_llm_model: str | None = None
default_llm_api_key_for_byor: str | None = None
default_llm_base_url: str | None = None
search_api_key: str | None = None
security_analyzer: str | None = None
agent: str | None = None
confirmation_mode: bool | None = None
enable_default_condenser: bool | None = None
condenser_max_size: int | None = Field(default=None, ge=20)
+311 -26
View File
@@ -1,20 +1,98 @@
from fastapi import APIRouter, Depends, HTTPException, status
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from server.email_validation import get_admin_user_id
from server.routes.org_models import (
LiteLLMIntegrationError,
OrgAuthorizationError,
OrgCreate,
OrgDatabaseError,
OrgNameExistsError,
OrgNotFoundError,
OrgPage,
OrgResponse,
OrgUpdate,
)
from storage.org_service import OrgService
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
# Initialize API router
org_router = APIRouter(prefix='/api/organizations')
@org_router.get('', response_model=OrgPage)
async def list_user_orgs(
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int,
Query(title='The max number of results in the page', gt=0, lte=100),
] = 100,
user_id: str = Depends(get_user_id),
) -> OrgPage:
"""List organizations for the authenticated user.
This endpoint returns a paginated list of all organizations that the
authenticated user is a member of.
Args:
page_id: Optional page ID (offset) for pagination
limit: Maximum number of organizations to return (1-100, default 100)
user_id: Authenticated user ID (injected by dependency)
Returns:
OrgPage: Paginated list of organizations
Raises:
HTTPException: 500 if retrieval fails
"""
logger.info(
'Listing organizations for user',
extra={
'user_id': user_id,
'page_id': page_id,
'limit': limit,
},
)
try:
# Fetch organizations from service layer
orgs, next_page_id = OrgService.get_user_orgs_paginated(
user_id=user_id,
page_id=page_id,
limit=limit,
)
# Convert Org entities to OrgResponse objects
org_responses = [OrgResponse.from_org(org, credits=None) for org in orgs]
logger.info(
'Successfully retrieved organizations',
extra={
'user_id': user_id,
'org_count': len(org_responses),
'has_more': next_page_id is not None,
},
)
return OrgPage(items=org_responses, next_page_id=next_page_id)
except Exception as e:
logger.exception(
'Unexpected error listing organizations',
extra={'user_id': user_id, 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to retrieve organizations',
)
@org_router.post('', response_model=OrgResponse, status_code=status.HTTP_201_CREATED)
async def create_org(
org_data: OrgCreate,
@@ -58,31 +136,7 @@ async def create_org(
# Retrieve credits from LiteLLM
credits = await OrgService.get_org_credits(user_id, org.id)
return OrgResponse(
id=str(org.id),
name=org.name,
contact_name=org.contact_name,
contact_email=org.contact_email,
conversation_expiration=org.conversation_expiration,
agent=org.agent,
default_max_iterations=org.default_max_iterations,
security_analyzer=org.security_analyzer,
confirmation_mode=org.confirmation_mode,
default_llm_model=org.default_llm_model,
default_llm_base_url=org.default_llm_base_url,
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
enable_default_condenser=org.enable_default_condenser,
billing_margin=org.billing_margin,
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters,
sandbox_base_container_image=org.sandbox_base_container_image,
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
org_version=org.org_version,
mcp_config=org.mcp_config,
max_budget_per_task=org.max_budget_per_task,
enable_solvability_analysis=org.enable_solvability_analysis,
v1_enabled=org.v1_enabled,
credits=credits,
)
return OrgResponse.from_org(org, credits=credits)
except OrgNameExistsError as e:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
@@ -115,3 +169,234 @@ async def create_org(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
async def get_org(
org_id: UUID,
user_id: str = Depends(get_user_id),
) -> OrgResponse:
"""Get organization details by ID.
This endpoint allows authenticated users who are members of an organization
to retrieve its details. Only members of the organization can access this endpoint.
Args:
org_id: Organization ID (UUID)
user_id: Authenticated user ID (injected by dependency)
Returns:
OrgResponse: The organization details
Raises:
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
HTTPException: 404 if organization not found or user is not a member
HTTPException: 500 if retrieval fails
"""
logger.info(
'Retrieving organization details',
extra={
'user_id': user_id,
'org_id': str(org_id),
},
)
try:
# Use service layer to get organization with membership validation
org = await OrgService.get_org_by_id(
org_id=org_id,
user_id=user_id,
)
# Retrieve credits from LiteLLM
credits = await OrgService.get_org_credits(user_id, org.id)
return OrgResponse.from_org(org, credits=credits)
except OrgNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except Exception as e:
logger.exception(
'Unexpected error retrieving organization',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@org_router.delete('/{org_id}', status_code=status.HTTP_200_OK)
async def delete_org(
org_id: UUID,
user_id: str = Depends(get_admin_user_id),
) -> dict:
"""Delete an organization.
This endpoint allows authenticated organization owners to delete their organization.
All associated data including organization members, conversations, billing data,
and external LiteLLM team resources will be permanently removed.
Args:
org_id: Organization ID to delete
user_id: Authenticated user ID (injected by dependency)
Returns:
dict: Confirmation message with deleted organization details
Raises:
HTTPException: 403 if user is not the organization owner
HTTPException: 404 if organization not found
HTTPException: 500 if deletion fails
"""
logger.info(
'Organization deletion requested',
extra={
'user_id': user_id,
'org_id': str(org_id),
},
)
try:
# Use service layer to delete organization with cleanup
deleted_org = await OrgService.delete_org_with_cleanup(
user_id=user_id,
org_id=org_id,
)
logger.info(
'Organization deletion completed successfully',
extra={
'user_id': user_id,
'org_id': str(org_id),
'org_name': deleted_org.name,
},
)
return {
'message': 'Organization deleted successfully',
'organization': {
'id': str(deleted_org.id),
'name': deleted_org.name,
'contact_name': deleted_org.contact_name,
'contact_email': deleted_org.contact_email,
},
}
except OrgNotFoundError as e:
logger.warning(
'Organization not found for deletion',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except OrgAuthorizationError as e:
logger.warning(
'User not authorized to delete organization',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except OrgDatabaseError as e:
logger.error(
'Database error during organization deletion',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to delete organization',
)
except Exception as e:
logger.exception(
'Unexpected error during organization deletion',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@org_router.patch('/{org_id}', response_model=OrgResponse)
async def update_org(
org_id: UUID,
update_data: OrgUpdate,
user_id: str = Depends(get_user_id),
) -> OrgResponse:
"""Update an existing organization.
This endpoint allows authenticated users to update organization settings.
LLM-related settings require admin or owner role in the organization.
Args:
org_id: Organization ID to update (UUID validated by FastAPI)
update_data: Organization update data
user_id: Authenticated user ID (injected by dependency)
Returns:
OrgResponse: The updated organization details
Raises:
HTTPException: 400 if org_id is invalid UUID format (handled by FastAPI)
HTTPException: 403 if user lacks permission for LLM settings
HTTPException: 404 if organization not found
HTTPException: 422 if validation errors occur (handled by FastAPI)
HTTPException: 500 if update fails
"""
logger.info(
'Updating organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
},
)
try:
# Use service layer to update organization with permission checks
updated_org = await OrgService.update_org_with_permissions(
org_id=org_id,
update_data=update_data,
user_id=user_id,
)
# Retrieve credits from LiteLLM (following same pattern as create endpoint)
credits = await OrgService.get_org_credits(user_id, updated_org.id)
return OrgResponse.from_org(updated_org, credits=credits)
except ValueError as e:
# Organization not found
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except PermissionError as e:
# User lacks permission for LLM settings
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except OrgDatabaseError as e:
logger.error(
'Database operation failed',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to update organization',
)
except Exception as e:
logger.exception(
'Unexpected error updating organization',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@@ -26,6 +26,7 @@ from server.sharing.shared_conversation_models import (
)
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
@@ -57,7 +58,7 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
include_sub_conversations: bool = False,
) -> SharedConversationPage:
"""Search for shared conversations."""
query = self._public_select()
query = self._public_select_with_saas_metadata()
# Conditionally exclude sub-conversations based on the parameter
if not include_sub_conversations:
@@ -104,14 +105,17 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
query = query.limit(limit + 1)
result = await self.db_session.execute(query)
rows = result.scalars().all()
rows = result.all()
# Check if there are more results
has_more = len(rows) > limit
if has_more:
rows = rows[:limit]
items = [self._to_shared_conversation(row) for row in rows]
items = [
self._to_shared_conversation(stored, saas_metadata=saas_metadata)
for stored, saas_metadata in rows
]
# Calculate next page ID
next_page_id = None
@@ -152,17 +156,18 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
self, conversation_id: UUID
) -> SharedConversation | None:
"""Get a single public conversation info, returning None if missing or not shared."""
query = self._public_select().where(
query = self._public_select_with_saas_metadata().where(
StoredConversationMetadata.conversation_id == str(conversation_id)
)
result = await self.db_session.execute(query)
stored = result.scalar_one_or_none()
row = result.first()
if stored is None:
if row is None:
return None
return self._to_shared_conversation(stored)
stored, saas_metadata = row
return self._to_shared_conversation(stored, saas_metadata=saas_metadata)
def _public_select(self):
"""Create a select query that only returns public conversations."""
@@ -173,6 +178,25 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
return query
def _public_select_with_saas_metadata(self):
"""Create a select query that returns public conversations with SAAS metadata.
This joins with conversation_metadata_saas to retrieve the user_id needed
for constructing the correct event storage path. Uses LEFT OUTER JOIN to
support conversations that may not have SAAS metadata (e.g., in tests).
"""
query = (
select(StoredConversationMetadata, StoredConversationMetadataSaas)
.outerjoin(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.where(StoredConversationMetadata.conversation_version == 'V1')
.where(StoredConversationMetadata.public == True) # noqa: E712
)
return query
def _apply_filters(
self,
query,
@@ -211,9 +235,16 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
def _to_shared_conversation(
self,
stored: StoredConversationMetadata,
saas_metadata: StoredConversationMetadataSaas | None = None,
sub_conversation_ids: list[UUID] | None = None,
) -> SharedConversation:
"""Convert StoredConversationMetadata to SharedConversation."""
"""Convert StoredConversationMetadata to SharedConversation.
Args:
stored: The base conversation metadata from conversation_metadata table.
saas_metadata: Optional SAAS metadata containing user_id and org_id.
sub_conversation_ids: Optional list of sub-conversation IDs.
"""
# V1 conversations should always have a sandbox_id
sandbox_id = stored.sandbox_id
assert sandbox_id is not None
@@ -239,9 +270,16 @@ class SQLSharedConversationInfoService(SharedConversationInfoService):
created_at = self._fix_timezone(stored.created_at)
updated_at = self._fix_timezone(stored.last_updated_at)
# Get user_id from SAAS metadata if available
created_by_user_id = (
str(saas_metadata.user_id)
if saas_metadata and saas_metadata.user_id
else None
)
return SharedConversation(
id=UUID(stored.conversation_id),
created_by_user_id=None, # user_id is no longer stored in conversation metadata
created_by_user_id=created_by_user_id,
sandbox_id=stored.sandbox_id,
selected_repository=stored.selected_repository,
selected_branch=stored.selected_branch,
+23
View File
@@ -98,6 +98,29 @@ def decrypt_legacy_value(value: str | SecretStr) -> str:
return get_fernet().decrypt(b64decode(value.encode())).decode()
def encrypt_legacy_model(encrypt_keys: list, model_instance) -> dict:
return encrypt_legacy_kwargs(encrypt_keys, model_to_kwargs(model_instance))
def encrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
for key, value in kwargs.items():
if value is None:
continue
if key in encrypt_keys:
value = encrypt_legacy_value(value)
kwargs[key] = value
return kwargs
def encrypt_legacy_value(value: str | SecretStr) -> str:
if isinstance(value, SecretStr):
return b64encode(
get_fernet().encrypt(value.get_secret_value().encode())
).decode()
else:
return b64encode(get_fernet().encrypt(value.encode())).decode()
def get_fernet():
global _fernet
if _fernet is None:
+222 -1
View File
@@ -96,7 +96,7 @@ class LiteLlmManager:
user_settings: UserSettings,
) -> UserSettings | None:
logger.info(
'SettingsStore:umigrate_lite_llm_entries:start',
'LiteLlmManager:migrate_lite_llm_entries:start',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
@@ -141,19 +141,35 @@ class LiteLlmManager:
return None
credits = max(max_budget - spend, 0.0)
logger.debug(
'LiteLlmManager:migrate_lite_llm_entries:create_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._create_team(
client, keycloak_user_id, org_id, credits
)
logger.debug(
'LiteLlmManager:migrate_lite_llm_entries:update_user',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_user(
client, keycloak_user_id, max_budget=UNLIMITED_BUDGET_SETTING
)
logger.debug(
'LiteLlmManager:migrate_lite_llm_entries:add_user_to_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._add_user_to_team(
client, keycloak_user_id, org_id, credits
)
if user_settings.llm_api_key:
logger.debug(
'LiteLlmManager:migrate_lite_llm_entries:update_key',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_key(
client,
keycloak_user_id,
@@ -162,6 +178,10 @@ class LiteLlmManager:
)
if user_settings.llm_api_key_for_byor:
logger.debug(
'LiteLlmManager:migrate_lite_llm_entries:update_byor_key',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_key(
client,
keycloak_user_id,
@@ -169,6 +189,167 @@ class LiteLlmManager:
team_id=org_id,
)
logger.info(
'LiteLlmManager:migrate_lite_llm_entries:complete',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
return user_settings
@staticmethod
async def downgrade_entries(
org_id: str,
keycloak_user_id: str,
user_settings: UserSettings,
) -> UserSettings | None:
"""Downgrade a migrated user's LiteLLM entries back to the pre-migration state.
This reverses the migrate_entries operation:
1. Get the user max budget from their org team in litellm
2. Set the max budget in the user in litellm (restore from team)
3. Add the user back to the default team in litellm
4. Update keys to remove org team association
5. Remove the user from their org team in litellm
6. Delete the user org team in litellm
Note: The database changes (already_migrated flag, org/org_member deletion)
should be handled separately by the caller.
Args:
org_id: The organization ID (which is also the team_id in litellm)
keycloak_user_id: The user's Keycloak ID
user_settings: The user's settings object
Returns:
The user_settings if downgrade was successful, None otherwise
"""
logger.info(
'LiteLlmManager:downgrade_entries:start',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
if not local_deploy:
async with httpx.AsyncClient(
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
}
) as client:
# Step 1: Get the team info to retrieve the budget
logger.debug(
'LiteLlmManager:downgrade_entries:get_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
team_info = await LiteLlmManager._get_team(client, org_id)
if not team_info:
logger.error(
'LiteLlmManager:downgrade_entries:team_not_found',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
return None
# Get team budget (max_budget) and spend to calculate current credits
team_data = team_info.get('team_info', {})
max_budget = team_data.get('max_budget', 0.0)
spend = team_data.get('spend', 0.0)
# Get user membership info for budget in team
user_membership = await LiteLlmManager._get_user_team_info(
client, keycloak_user_id, org_id
)
if user_membership:
# Use user's budget in team if available
user_max_budget_in_team = user_membership.get('max_budget_in_team')
user_spend_in_team = user_membership.get('spend', 0.0)
if user_max_budget_in_team is not None:
max_budget = user_max_budget_in_team
spend = user_spend_in_team
# Calculate total budget to restore (credits + spend = max_budget)
# We restore the full max_budget that was on the team/user-in-team
restored_budget = max_budget if max_budget else 0.0
logger.debug(
'LiteLlmManager:downgrade_entries:budget_info',
extra={
'org_id': org_id,
'user_id': keycloak_user_id,
'max_budget': max_budget,
'spend': spend,
'restored_budget': restored_budget,
},
)
# Step 2: Update user to set their max_budget back from unlimited
logger.debug(
'LiteLlmManager:downgrade_entries:update_user',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_user(
client, keycloak_user_id, max_budget=restored_budget, spend=spend
)
# Step 3: Add user back to the default team
if LITE_LLM_TEAM_ID:
logger.debug(
'LiteLlmManager:downgrade_entries:add_to_default_team',
extra={
'org_id': org_id,
'user_id': keycloak_user_id,
'default_team_id': LITE_LLM_TEAM_ID,
},
)
await LiteLlmManager._add_user_to_team(
client, keycloak_user_id, LITE_LLM_TEAM_ID, restored_budget
)
# Step 4: Update keys to remove org team association (set team_id to default)
if user_settings.llm_api_key:
logger.debug(
'LiteLlmManager:downgrade_entries:update_key',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_key(
client,
keycloak_user_id,
user_settings.llm_api_key,
team_id=LITE_LLM_TEAM_ID,
)
if user_settings.llm_api_key_for_byor:
logger.debug(
'LiteLlmManager:downgrade_entries:update_byor_key',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._update_key(
client,
keycloak_user_id,
user_settings.llm_api_key_for_byor,
team_id=LITE_LLM_TEAM_ID,
)
# Step 5: Remove user from their org team
logger.debug(
'LiteLlmManager:downgrade_entries:remove_from_org_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._remove_user_from_team(
client, keycloak_user_id, org_id
)
# Step 6: Delete the org team
logger.debug(
'LiteLlmManager:downgrade_entries:delete_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._delete_team(client, org_id)
logger.info(
'LiteLlmManager:downgrade_entries:complete',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
return user_settings
@staticmethod
@@ -613,6 +794,45 @@ class LiteLlmManager:
)
response.raise_for_status()
@staticmethod
async def _remove_user_from_team(
client: httpx.AsyncClient,
keycloak_user_id: str,
team_id: str,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/team/member_delete',
json={
'team_id': team_id,
'user_id': keycloak_user_id,
},
)
if not response.is_success:
if response.status_code == 404:
# User not in team, that's fine for downgrade
logger.info(
'User not in team during removal',
extra={'user_id': keycloak_user_id, 'team_id': team_id},
)
return
logger.error(
'error_removing_litellm_user_from_team',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': keycloak_user_id,
'team_id': team_id,
},
)
response.raise_for_status()
logger.info(
'LiteLlmManager:_remove_user_from_team:user_removed',
extra={'user_id': keycloak_user_id, 'team_id': team_id},
)
@staticmethod
async def _generate_key(
client: httpx.AsyncClient,
@@ -856,6 +1076,7 @@ class LiteLlmManager:
delete_user = staticmethod(with_http_client(_delete_user))
delete_team = staticmethod(with_http_client(_delete_team))
add_user_to_team = staticmethod(with_http_client(_add_user_to_team))
remove_user_from_team = staticmethod(with_http_client(_remove_user_from_team))
get_user_team_info = staticmethod(with_http_client(_get_user_team_info))
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
generate_key = staticmethod(with_http_client(_generate_key))
+401
View File
@@ -9,8 +9,11 @@ from uuid import UUID as parse_uuid
from server.constants import ORG_SETTINGS_VERSION, get_default_litellm_model
from server.routes.org_models import (
LiteLLMIntegrationError,
OrgAuthorizationError,
OrgDatabaseError,
OrgNameExistsError,
OrgNotFoundError,
OrgUpdate,
)
from storage.lite_llm_manager import LiteLlmManager
from storage.org import Org
@@ -393,6 +396,224 @@ class OrgService:
)
return e
@staticmethod
def has_admin_or_owner_role(user_id: str, org_id: UUID) -> bool:
"""
Check if user has admin or owner role in the specified organization.
Args:
user_id: User ID to check
org_id: Organization ID to check membership in
Returns:
bool: True if user has admin or owner role, False otherwise
"""
try:
# Parse user_id as UUID for database query
user_uuid = parse_uuid(user_id)
# Get the user's membership in this organization
# Note: The type annotation says int but the actual column is UUID
org_member = OrgMemberStore.get_org_member(org_id, user_uuid)
if not org_member:
return False
# Get the role details
role = RoleStore.get_role_by_id(org_member.role_id)
if not role:
return False
# Admin and owner roles have elevated permissions
# Based on test files, both admin and owner have rank 1
return role.name in ['admin', 'owner']
except Exception as e:
logger.warning(
'Error checking user role in organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
'error': str(e),
},
)
return False
@staticmethod
def is_org_member(user_id: str, org_id: UUID) -> bool:
"""
Check if user is a member of the specified organization.
Args:
user_id: User ID to check
org_id: Organization ID to check membership in
Returns:
bool: True if user is a member, False otherwise
"""
try:
user_uuid = parse_uuid(user_id)
org_member = OrgMemberStore.get_org_member(org_id, user_uuid)
return org_member is not None
except Exception as e:
logger.warning(
'Error checking user membership in organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
'error': str(e),
},
)
return False
@staticmethod
def _get_llm_settings_fields() -> set[str]:
"""
Get the set of organization fields that are considered LLM settings
and require admin/owner role to update.
Returns:
set[str]: Set of field names that require elevated permissions
"""
return {
'default_llm_model',
'default_llm_api_key_for_byor',
'default_llm_base_url',
'search_api_key',
'security_analyzer',
'agent',
'confirmation_mode',
'enable_default_condenser',
'condenser_max_size',
}
@staticmethod
def _has_llm_settings_updates(update_data: OrgUpdate) -> set[str]:
"""
Check if the update contains any LLM settings fields.
Args:
update_data: The organization update data
Returns:
set[str]: Set of LLM fields being updated (empty if none)
"""
llm_fields = OrgService._get_llm_settings_fields()
update_dict = update_data.model_dump(exclude_none=True)
return llm_fields.intersection(update_dict.keys())
@staticmethod
async def update_org_with_permissions(
org_id: UUID,
update_data: OrgUpdate,
user_id: str,
) -> Org:
"""
Update organization with permission checks for LLM settings.
Args:
org_id: Organization UUID to update
update_data: Organization update data from request
user_id: ID of the user requesting the update
Returns:
Org: The updated organization object
Raises:
ValueError: If organization not found
PermissionError: If user is not a member, or lacks admin/owner role for LLM settings
OrgDatabaseError: If database update fails
"""
logger.info(
'Updating organization with permission checks',
extra={
'org_id': str(org_id),
'user_id': user_id,
'has_update_data': update_data is not None,
},
)
# Validate organization exists
existing_org = OrgStore.get_org_by_id(org_id)
if not existing_org:
raise ValueError(f'Organization with ID {org_id} not found')
# Check if user is a member of this organization
if not OrgService.is_org_member(user_id, org_id):
logger.warning(
'Non-member attempted to update organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
},
)
raise PermissionError(
'User must be a member of the organization to update it'
)
# Check if update contains any LLM settings
llm_fields_being_updated = OrgService._has_llm_settings_updates(update_data)
if llm_fields_being_updated:
# Verify user has admin or owner role
has_permission = OrgService.has_admin_or_owner_role(user_id, org_id)
if not has_permission:
logger.warning(
'User attempted to update LLM settings without permission',
extra={
'user_id': user_id,
'org_id': str(org_id),
'attempted_fields': list(llm_fields_being_updated),
},
)
raise PermissionError(
'Admin or owner role required to update LLM settings'
)
logger.debug(
'User has permission to update LLM settings',
extra={
'user_id': user_id,
'org_id': str(org_id),
'llm_fields': list(llm_fields_being_updated),
},
)
# Convert to dict for OrgStore (excluding None values)
update_dict = update_data.model_dump(exclude_none=True)
if not update_dict:
logger.info(
'No fields to update',
extra={'org_id': str(org_id), 'user_id': user_id},
)
return existing_org
# Perform the update
try:
updated_org = OrgStore.update_org(org_id, update_dict)
if not updated_org:
raise OrgDatabaseError('Failed to update organization in database')
logger.info(
'Organization updated successfully',
extra={
'org_id': str(org_id),
'user_id': user_id,
'updated_fields': list(update_dict.keys()),
},
)
return updated_org
except Exception as e:
logger.error(
'Failed to update organization',
extra={
'org_id': str(org_id),
'user_id': user_id,
'error': str(e),
},
)
raise OrgDatabaseError(f'Failed to update organization: {str(e)}')
@staticmethod
async def get_org_credits(user_id: str, org_id: UUID) -> float | None:
"""
@@ -441,3 +662,183 @@ class OrgService:
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
return None
@staticmethod
def get_user_orgs_paginated(
user_id: str, page_id: str | None = None, limit: int = 100
):
"""
Get paginated list of organizations for a user.
Args:
user_id: User ID (string that will be converted to UUID)
page_id: Optional page ID (offset as string) for pagination
limit: Maximum number of organizations to return
Returns:
Tuple of (list of Org objects, next_page_id or None)
"""
logger.debug(
'Fetching paginated organizations for user',
extra={'user_id': user_id, 'page_id': page_id, 'limit': limit},
)
# Convert user_id string to UUID
user_uuid = parse_uuid(user_id)
# Fetch organizations from store
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
user_id=user_uuid, page_id=page_id, limit=limit
)
logger.debug(
'Retrieved organizations for user',
extra={
'user_id': user_id,
'org_count': len(orgs),
'has_more': next_page_id is not None,
},
)
return orgs, next_page_id
@staticmethod
async def get_org_by_id(org_id: UUID, user_id: str) -> Org:
"""
Get organization by ID with membership validation.
This method verifies that the user is a member of the organization
before returning the organization details.
Args:
org_id: Organization ID
user_id: User ID (string that will be converted to UUID)
Returns:
Org: The organization object
Raises:
OrgNotFoundError: If organization not found or user is not a member
"""
logger.info(
'Retrieving organization',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
# Verify user is a member of the organization
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
if not org_member:
logger.warning(
'User is not a member of organization or organization does not exist',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise OrgNotFoundError(str(org_id))
# Retrieve organization
org = OrgStore.get_org_by_id(org_id)
if not org:
logger.error(
'Organization not found despite valid membership',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise OrgNotFoundError(str(org_id))
logger.info(
'Successfully retrieved organization',
extra={
'org_id': str(org.id),
'org_name': org.name,
'user_id': user_id,
},
)
return org
@staticmethod
def verify_owner_authorization(user_id: str, org_id: UUID) -> None:
"""
Verify that the user is the owner of the organization.
Args:
user_id: User ID to check
org_id: Organization ID
Raises:
OrgNotFoundError: If organization doesn't exist
OrgAuthorizationError: If user is not authorized to delete
"""
# Check if organization exists
org = OrgStore.get_org_by_id(org_id)
if not org:
raise OrgNotFoundError(str(org_id))
# Check if user is a member of the organization
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
if not org_member:
raise OrgAuthorizationError('User is not a member of this organization')
# Check if user has owner role
role = RoleStore.get_role_by_id(org_member.role_id)
if not role or role.name != 'owner':
raise OrgAuthorizationError(
'Only organization owners can delete organizations'
)
logger.debug(
'User authorization verified for organization deletion',
extra={'user_id': user_id, 'org_id': str(org_id), 'role': role.name},
)
@staticmethod
async def delete_org_with_cleanup(user_id: str, org_id: UUID) -> Org:
"""
Delete organization with complete cleanup of all associated data.
This method performs the complete organization deletion workflow:
1. Verifies user authorization (owner only)
2. Performs database cascade deletion and LiteLLM cleanup in single transaction
Args:
user_id: User ID requesting deletion (must be owner)
org_id: Organization ID to delete
Returns:
Org: The deleted organization details
Raises:
OrgNotFoundError: If organization doesn't exist
OrgAuthorizationError: If user is not authorized to delete
OrgDatabaseError: If database operations or LiteLLM cleanup fail
"""
logger.info(
'Starting organization deletion',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
# Step 1: Verify user authorization
OrgService.verify_owner_authorization(user_id, org_id)
# Step 2: Perform database cascade deletion with LiteLLM cleanup in transaction
try:
deleted_org = await OrgStore.delete_org_cascade(org_id)
if not deleted_org:
# This shouldn't happen since we verified existence above
raise OrgDatabaseError('Organization not found during deletion')
logger.info(
'Organization deletion completed successfully',
extra={
'user_id': user_id,
'org_id': str(org_id),
'org_name': deleted_org.name,
},
)
return deleted_org
except Exception as e:
logger.error(
'Organization deletion failed',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise OrgDatabaseError(f'Failed to delete organization: {str(e)}')
+175
View File
@@ -10,8 +10,10 @@ from server.constants import (
ORG_SETTINGS_VERSION,
get_default_litellm_model,
)
from sqlalchemy import text
from sqlalchemy.orm import joinedload
from storage.database import session_maker
from storage.lite_llm_manager import LiteLlmManager
from storage.org import Org
from storage.org_member import OrgMember
from storage.user import User
@@ -96,6 +98,63 @@ class OrgStore:
orgs = session.query(Org).all()
return orgs
@staticmethod
def get_user_orgs_paginated(
user_id: UUID, page_id: str | None = None, limit: int = 100
) -> tuple[list[Org], str | None]:
"""
Get paginated list of organizations for a user.
Args:
user_id: User UUID
page_id: Optional page ID (offset as string) for pagination
limit: Maximum number of organizations to return
Returns:
Tuple of (list of Org objects, next_page_id or None)
"""
with session_maker() as session:
# Build query joining OrgMember with Org
query = (
session.query(Org)
.join(OrgMember, Org.id == OrgMember.org_id)
.filter(OrgMember.user_id == user_id)
.order_by(Org.name)
)
# Apply pagination offset
if page_id is not None:
try:
offset = int(page_id)
query = query.offset(offset)
except ValueError:
# If page_id is not a valid integer, start from beginning
offset = 0
else:
offset = 0
# Fetch limit + 1 to check if there are more results
query = query.limit(limit + 1)
orgs = query.all()
# Check if there are more results
has_more = len(orgs) > limit
if has_more:
orgs = orgs[:limit]
# Calculate next page ID
next_page_id = None
if has_more:
next_page_id = str(offset + limit)
# Validate org versions
validated_orgs = [
OrgStore._validate_org_version(org) for org in orgs if org
]
validated_orgs = [org for org in validated_orgs if org is not None]
return validated_orgs, next_page_id
@staticmethod
def update_org(
org_id: UUID,
@@ -186,3 +245,119 @@ class OrgStore:
session.commit()
session.refresh(org)
return org
@staticmethod
async def delete_org_cascade(org_id: UUID) -> Org | None:
"""
Delete organization and all associated data in cascade, including external LiteLLM cleanup.
Args:
org_id: UUID of the organization to delete
Returns:
Org: The deleted organization object, or None if not found
Raises:
Exception: If database operations or LiteLLM cleanup fail
"""
with session_maker() as session:
# First get the organization to return it
org = session.query(Org).filter(Org.id == org_id).first()
if not org:
return None
try:
# 1. Delete conversation data for organization conversations
session.execute(
text("""
DELETE FROM conversation_metadata
WHERE conversation_id IN (
SELECT conversation_id FROM conversation_metadata_saas WHERE org_id = :org_id
)
"""),
{'org_id': str(org_id)},
)
session.execute(
text("""
DELETE FROM app_conversation_start_task
WHERE app_conversation_id::text IN (
SELECT conversation_id FROM conversation_metadata_saas WHERE org_id = :org_id
)
"""),
{'org_id': str(org_id)},
)
# 2. Delete organization-owned data tables (direct org_id foreign keys)
session.execute(
text('DELETE FROM billing_sessions WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text(
'DELETE FROM conversation_metadata_saas WHERE org_id = :org_id'
),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM custom_secrets WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM api_keys WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM slack_conversation WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM slack_users WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
session.execute(
text('DELETE FROM stripe_customers WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
# 3. Delete organization memberships
session.execute(
text('DELETE FROM org_member WHERE org_id = :org_id'),
{'org_id': str(org_id)},
)
# 4. Handle users with this as current_org_id
session.execute(
text(
'UPDATE "user" SET current_org_id = NULL WHERE current_org_id = :org_id'
),
{'org_id': str(org_id)},
)
# 5. Finally delete the organization
session.delete(org)
# 6. Clean up LiteLLM team before committing transaction
logger.info(
'Deleting LiteLLM team within database transaction',
extra={'org_id': str(org_id)},
)
await LiteLlmManager.delete_team(str(org_id))
# 7. Commit all changes only if everything succeeded
session.commit()
logger.info(
'Successfully deleted organization and all associated data including LiteLLM team',
extra={'org_id': str(org_id), 'org_name': org.name},
)
return org
except Exception as e:
session.rollback()
logger.error(
'Failed to delete organization - transaction rolled back',
extra={'org_id': str(org_id), 'error': str(e)},
)
raise
+17 -1
View File
@@ -4,7 +4,9 @@ Store class for managing roles.
from typing import List, Optional
from storage.database import session_maker
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from storage.database import a_session_maker, session_maker
from storage.role import Role
@@ -33,6 +35,20 @@ class RoleStore:
with session_maker() as session:
return session.query(Role).filter(Role.name == name).first()
@staticmethod
async def get_role_by_name_async(
name: str,
session: Optional[AsyncSession] = None,
) -> Optional[Role]:
"""Get role by name."""
if session is not None:
result = await session.execute(select(Role).where(Role.name == name))
return result.scalars().first()
async with a_session_maker() as session:
result = await session.execute(select(Role).where(Role.name == name))
return result.scalars().first()
@staticmethod
def list_roles() -> List[Role]:
"""List all roles."""
+2
View File
@@ -31,6 +31,8 @@ class User(Base): # type: ignore
user_consents_to_analytics = Column(Boolean, nullable=True)
email = Column(String, nullable=True)
email_verified = Column(Boolean, nullable=True)
git_user_name = Column(String, nullable=True)
git_user_email = Column(String, nullable=True)
# Relationships
role = relationship('Role', back_populates='users')
+400 -19
View File
@@ -14,10 +14,13 @@ from server.constants import (
get_default_litellm_model,
)
from server.logger import logger
from sqlalchemy import text
from sqlalchemy import select, text
from sqlalchemy.orm import joinedload
from storage.database import session_maker
from storage.encrypt_utils import decrypt_legacy_model
from storage.database import a_session_maker, session_maker
from storage.encrypt_utils import (
decrypt_legacy_model,
encrypt_legacy_value,
)
from storage.org import Org
from storage.org_member import OrgMember
from storage.role_store import RoleStore
@@ -116,7 +119,7 @@ class UserStore:
redis_client = UserStore._get_redis_client()
if redis_client is None:
logger.warning(
'saas_settings_store:_acquire_user_creation_lock:no_redis_client',
'user_store:_acquire_user_creation_lock:no_redis_client',
extra={'user_id': user_id},
)
return True # Proceed without locking if Redis is unavailable
@@ -159,12 +162,20 @@ class UserStore:
from storage.lite_llm_manager import LiteLlmManager
logger.debug(
'user_store:migrate_user:calling_litellm_migrate_entries',
extra={'user_id': user_id},
)
await LiteLlmManager.migrate_entries(
str(org.id),
user_id,
decrypted_user_settings,
)
logger.debug(
'user_store:migrate_user:done_litellm_migrate_entries',
extra={'user_id': user_id},
)
custom_settings = UserStore._has_custom_settings(
decrypted_user_settings, user_settings.user_version
)
@@ -172,7 +183,15 @@ class UserStore:
# avoids circular reference. This migrate method is temprorary until all users are migrated.
from integrations.stripe_service import migrate_customer
logger.debug(
'user_store:migrate_user:calling_stripe_migrate_customer',
extra={'user_id': user_id},
)
await migrate_customer(session, user_id, org)
logger.debug(
'user_store:migrate_user:done_stripe_migrate_customer',
extra={'user_id': user_id},
)
from storage.org_store import OrgStore
@@ -201,7 +220,15 @@ class UserStore:
)
session.add(user)
role = RoleStore.get_role_by_name('owner')
logger.debug(
'user_store:migrate_user:calling_get_role_by_name',
extra={'user_id': user_id},
)
role = await RoleStore.get_role_by_name_async('owner')
logger.debug(
'user_store:migrate_user:done_get_role_by_name',
extra={'user_id': user_id},
)
from storage.org_member_store import OrgMemberStore
@@ -214,7 +241,6 @@ class UserStore:
if not custom_settings:
del org_member_kwargs['llm_model']
del org_member_kwargs['llm_base_url']
del org_member_kwargs['llm_api_key_for_byor']
org_member = OrgMember(
org_id=org.id,
@@ -229,6 +255,10 @@ class UserStore:
user_settings.already_migrated = True
session.merge(user_settings)
session.flush()
logger.debug(
'user_store:migrate_user:session_flush_complete',
extra={'user_id': user_id},
)
# need to migrate conversation metadata
session.execute(
@@ -296,8 +326,262 @@ class UserStore:
session.commit()
session.refresh(user)
user.org_members # load org_members
logger.debug(
'user_store:migrate_user:session_committed',
extra={'user_id': user_id},
)
return user
@staticmethod
async def downgrade_user(user_id: str) -> UserSettings | None:
"""Downgrade a migrated user back to the pre-migration state.
This reverses the migrate_user operation:
1. Get the user's settings from user_settings table (migrated users) or
create new user_settings from org_members table (new sign-ups)
2. Call LiteLlmManager.downgrade_entries to revert LiteLLM state
3. Copy user_id from conversation_metadata_saas to conversation_metadata
4. Delete conversation_metadata_saas entries
5. Reset org_id columns in related tables (stripe_customers, slack_users, etc.)
6. Delete the org_member and org entries
7. Delete the user entry
8. Set already_migrated=False on user_settings
For new sign-ups (users who registered after migration was deployed),
there won't be an existing user_settings entry. In this case, we fall back
to the org_members table to get the user's API keys and settings, and create
a new user_settings entry for them.
Args:
user_id: The Keycloak user ID to downgrade
Returns:
The user_settings if downgrade was successful, None otherwise.
Returns None if the org has multiple members (not a personal org).
"""
logger.info(
'user_store:downgrade_user:start',
extra={'user_id': user_id},
)
with session_maker() as session:
# Get the user and their org_member
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
if not user:
logger.warning(
'user_store:downgrade_user:user_not_found',
extra={'user_id': user_id},
)
return None
# Get the user's personal org (org_id == user_id)
org = session.query(Org).filter(Org.id == uuid.UUID(user_id)).first()
if not org:
logger.warning(
'user_store:downgrade_user:org_not_found',
extra={'user_id': user_id},
)
return None
# Get the user_settings (for migrated users)
user_settings = (
session.query(UserSettings)
.filter(
UserSettings.keycloak_user_id == user_id,
UserSettings.already_migrated.is_(True),
)
.first()
)
# For new sign-ups after migration, user_settings won't exist
# Fall back to getting data from org_members
is_new_signup = False
if not user_settings:
logger.info(
'user_store:downgrade_user:user_settings_not_found_checking_org_members',
extra={'user_id': user_id},
)
# Get org_members for this org - should only be one for personal orgs
org_members = (
session.query(OrgMember).filter(OrgMember.org_id == org.id).all()
)
if len(org_members) != 1:
logger.error(
'user_store:downgrade_user:unexpected_org_members_count',
extra={
'user_id': user_id,
'org_id': str(org.id),
'org_members_count': len(org_members),
},
)
return None
org_member = org_members[0]
is_new_signup = True
# Create a new user_settings entry from OrgMember, User, and Org data
# This is needed for new sign-ups who don't have user_settings
user_settings = UserStore._create_user_settings_from_entities(
user_id, org_member, user, org
)
session.add(user_settings)
session.flush()
logger.info(
'user_store:downgrade_user:created_user_settings_from_org_member',
extra={'user_id': user_id},
)
# Call LiteLLM downgrade
from storage.lite_llm_manager import LiteLlmManager
logger.debug(
'user_store:downgrade_user:calling_litellm_downgrade_entries',
extra={'user_id': user_id},
)
# Get the API keys for LiteLLM downgrade
if is_new_signup:
# For new signups, we already have decrypted values in user_settings
decrypted_user_settings = user_settings
else:
# For migrated users, decrypt the legacy model
kwargs = decrypt_legacy_model(
[
'llm_api_key',
'llm_api_key_for_byor',
'search_api_key',
'sandbox_api_key',
],
user_settings,
)
decrypted_user_settings = UserSettings(**kwargs)
await LiteLlmManager.downgrade_entries(
str(org.id),
user_id,
decrypted_user_settings,
)
logger.debug(
'user_store:downgrade_user:done_litellm_downgrade_entries',
extra={'user_id': user_id},
)
user_uuid = uuid.UUID(user_id)
# Step 3: Copy user_id from conversation_metadata_saas to conversation_metadata
# This ensures any conversations created after migration have their user_id
# preserved in the original table before we delete the saas entries
session.execute(
text("""
UPDATE conversation_metadata
SET user_id = :user_id
WHERE conversation_id IN (
SELECT conversation_id
FROM conversation_metadata_saas
WHERE user_id = :user_uuid
)
"""),
{'user_id': user_id, 'user_uuid': user_uuid},
)
# Step 4: Delete conversation_metadata_saas entries
session.execute(
text('DELETE FROM conversation_metadata_saas WHERE user_id = :user_id'),
{'user_id': user_uuid},
)
# Step 5: Reset org_id columns in related tables
# Reset stripe_customers
session.execute(
text(
'UPDATE stripe_customers SET org_id = NULL WHERE org_id = :org_id'
),
{'org_id': user_uuid},
)
# Reset slack_users
session.execute(
text('UPDATE slack_users SET org_id = NULL WHERE org_id = :org_id'),
{'org_id': user_uuid},
)
# Reset slack_conversation
session.execute(
text(
'UPDATE slack_conversation SET org_id = NULL WHERE org_id = :org_id'
),
{'org_id': user_uuid},
)
# Reset api_keys
session.execute(
text('UPDATE api_keys SET org_id = NULL WHERE org_id = :org_id'),
{'org_id': user_uuid},
)
# Reset custom_secrets
session.execute(
text('UPDATE custom_secrets SET org_id = NULL WHERE org_id = :org_id'),
{'org_id': user_uuid},
)
# Reset billing_sessions
session.execute(
text(
'UPDATE billing_sessions SET org_id = NULL WHERE org_id = :org_id'
),
{'org_id': user_uuid},
)
# Step 6: Delete org_member entries for this org
session.execute(
text('DELETE FROM org_member WHERE org_id = :org_id'),
{'org_id': user_uuid},
)
# Step 7: Delete the user entry
session.execute(
text('DELETE FROM "user" WHERE id = :user_id'),
{'user_id': user_uuid},
)
# Delete the org entry
session.execute(
text('DELETE FROM org WHERE id = :org_id'),
{'org_id': user_uuid},
)
# Step 8: Set already_migrated=False on user_settings and encrypt fields
user_settings.already_migrated = False
# Re-encrypt the sensitive fields before storing in the DB
encrypt_keys = [
'llm_api_key',
'llm_api_key_for_byor',
'search_api_key',
'sandbox_api_key',
]
for key in encrypt_keys:
value = getattr(user_settings, key, None)
if value is not None:
setattr(user_settings, key, encrypt_legacy_value(value))
session.merge(user_settings)
session.commit()
logger.info(
'user_store:downgrade_user:complete',
extra={'user_id': user_id},
)
return user_settings
@staticmethod
def get_user_by_id(user_id: str) -> Optional[User]:
"""Get user by Keycloak user ID (sync version).
@@ -322,7 +606,7 @@ class UserStore:
):
# The user is already being created in another thread / process
logger.info(
'saas_settings_store:create_default_settings:waiting_for_lock',
'user_store:create_default_settings:waiting_for_lock',
extra={'user_id': user_id},
)
call_async_from_sync(
@@ -372,13 +656,13 @@ class UserStore:
This is the preferred method when calling from an async context as it
avoids event loop conflicts that can occur with the sync version.
"""
with session_maker() as session:
user = (
session.query(User)
async with a_session_maker() as session:
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
user = result.scalars().first()
if user:
return user
@@ -386,32 +670,39 @@ class UserStore:
while not await UserStore._acquire_user_creation_lock(user_id):
# The user is already being created in another thread / process
logger.info(
'saas_settings_store:create_default_settings:waiting_for_lock',
'user_store:get_user_by_id_async:waiting_for_lock',
extra={'user_id': user_id},
)
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
# Check for user again as migration could have happened while trying to get the lock.
user = (
session.query(User)
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
user = result.scalars().first()
if user:
return user
user_settings = (
session.query(UserSettings)
.filter(
logger.info(
'user_store:get_user_by_id_async:start_migration',
extra={'user_id': user_id},
)
result = await session.execute(
select(UserSettings).filter(
UserSettings.keycloak_user_id == user_id,
UserSettings.already_migrated.is_(False),
)
.first()
)
user_settings = result.scalars().first()
if user_settings:
token_manager = TokenManager()
user_info = await token_manager.get_user_info_from_user_id(user_id)
logger.info(
'user_store:get_user_by_id_async:calling_migrate_user',
extra={'user_id': user_id},
)
user = await UserStore.migrate_user(
user_id,
user_settings,
@@ -481,6 +772,96 @@ class UserStore:
}
return kwargs
@staticmethod
def _create_user_settings_from_entities(
user_id: str, org_member: OrgMember, user: User, org: Org
) -> UserSettings:
"""Create UserSettings from OrgMember, User, and Org data.
Uses OrgMember values first. If an OrgMember field is None and there's
a corresponding "default_" field in Org, use the Org value.
Also pulls relevant fields from User.
Args:
user_id: The Keycloak user ID
org_member: The OrgMember entity
user: The User entity
org: The Org entity
Returns:
A new UserSettings object populated from the entities
"""
# Mapping from OrgMember fields to corresponding Org "default_" fields
org_member_to_org_default = {
'llm_model': 'default_llm_model',
'llm_base_url': 'default_llm_base_url',
'max_iterations': 'default_max_iterations',
}
def get_value_with_org_fallback(field_name: str, org_member_value):
"""Get value from OrgMember, falling back to Org default if None."""
if org_member_value is not None:
return org_member_value
org_default_field = org_member_to_org_default.get(field_name)
if org_default_field and hasattr(org, org_default_field):
return getattr(org, org_default_field)
return None
# Get values from OrgMember with Org fallback for fields with default_ prefix
llm_model = get_value_with_org_fallback('llm_model', org_member.llm_model)
llm_base_url = get_value_with_org_fallback(
'llm_base_url', org_member.llm_base_url
)
max_iterations = get_value_with_org_fallback(
'max_iterations', org_member.max_iterations
)
return UserSettings(
keycloak_user_id=user_id,
# OrgMember fields
llm_api_key=org_member.llm_api_key.get_secret_value()
if org_member.llm_api_key
else None,
llm_api_key_for_byor=org_member.llm_api_key_for_byor.get_secret_value()
if org_member.llm_api_key_for_byor
else None,
llm_model=llm_model,
llm_base_url=llm_base_url,
max_iterations=max_iterations,
# User fields
accepted_tos=user.accepted_tos,
enable_sound_notifications=user.enable_sound_notifications,
language=user.language,
user_consents_to_analytics=user.user_consents_to_analytics,
email=user.email,
email_verified=user.email_verified,
git_user_name=user.git_user_name,
git_user_email=user.git_user_email,
# Org fields
agent=org.agent,
security_analyzer=org.security_analyzer,
confirmation_mode=org.confirmation_mode,
remote_runtime_resource_factor=org.remote_runtime_resource_factor,
enable_default_condenser=org.enable_default_condenser,
billing_margin=org.billing_margin,
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters,
sandbox_base_container_image=org.sandbox_base_container_image,
sandbox_runtime_container_image=org.sandbox_runtime_container_image,
user_version=org.org_version,
mcp_config=org.mcp_config,
search_api_key=org.search_api_key.get_secret_value()
if org.search_api_key
else None,
sandbox_api_key=org.sandbox_api_key.get_secret_value()
if org.sandbox_api_key
else None,
max_budget_per_task=org.max_budget_per_task,
enable_solvability_analysis=org.enable_solvability_analysis,
v1_enabled=org.v1_enabled,
condenser_max_size=org.condenser_max_size,
already_migrated=False,
)
@staticmethod
def _has_custom_settings(
user_settings: UserSettings, old_user_version: int | None
@@ -0,0 +1,96 @@
from unittest.mock import patch
from urllib.parse import parse_qs, urlparse
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import SecretStr
from server.routes.github_proxy import add_github_proxy_routes
@pytest.fixture
def app_with_github_proxy(monkeypatch):
"""Create a FastAPI app with github proxy routes enabled."""
# Enable the github proxy endpoints
monkeypatch.setenv('GITHUB_PROXY_ENDPOINTS', '1')
# Mock the config to have a jwt_secret
mock_config = type(
'MockConfig', (), {'jwt_secret': SecretStr('test-secret-key-for-testing')}
)()
app = FastAPI()
with patch('server.routes.github_proxy.GITHUB_PROXY_ENDPOINTS', True):
with patch('server.routes.github_proxy.config', mock_config):
add_github_proxy_routes(app)
# Return app and mock_config so we can use the same config in tests
return app, mock_config
def test_state_compress_encrypt_and_decrypt_decompress_roundtrip(
app_with_github_proxy, monkeypatch
):
"""
Verify the code path used by github_proxy_start -> github_proxy_callback:
- compress payload, encrypt, base64-encode (what the start code does)
- base64-decode, decrypt, decompress (what the callback code does)
This test exercises the actual endpoints to verify the roundtrip works correctly.
"""
app, mock_config = app_with_github_proxy
client = TestClient(app)
original_state = 'some-state-value'
original_redirect_uri = 'https://example.com/redirect'
# Call github_proxy_start endpoint - it should redirect to GitHub with encrypted state
with patch('server.routes.github_proxy.config', mock_config):
response = client.get(
'/github-proxy/test-subdomain/login/oauth/authorize',
params={
'state': original_state,
'redirect_uri': original_redirect_uri,
'client_id': 'test-client-id',
},
follow_redirects=False,
)
assert response.status_code == 307
redirect_url = response.headers['location']
# Verify it redirects to GitHub
assert redirect_url.startswith('https://github.com/login/oauth/authorize')
# Parse the redirect URL to get the encrypted state
parsed = urlparse(redirect_url)
query_params = parse_qs(parsed.query)
encrypted_state = query_params['state'][0]
# The redirect_uri should now point to our callback
assert 'github-proxy/callback' in query_params['redirect_uri'][0]
# Now simulate GitHub calling back with this encrypted state
with patch('server.routes.github_proxy.config', mock_config):
callback_response = client.get(
'/github-proxy/callback',
params={
'state': encrypted_state,
'code': 'test-auth-code',
},
follow_redirects=False,
)
assert callback_response.status_code == 307
final_redirect = callback_response.headers['location']
# Verify the callback redirects to the original redirect_uri
assert final_redirect.startswith(original_redirect_uri)
# Parse the final redirect to verify the state was decrypted correctly
final_parsed = urlparse(final_redirect)
final_params = parse_qs(final_parsed.query)
assert final_params['state'][0] == original_state
assert final_params['code'][0] == 'test-auth-code'
File diff suppressed because it is too large Load Diff
+10 -10
View File
@@ -163,7 +163,7 @@ async def test_create_checkout_session_stripe_error(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
),
patch('server.routes.billing.validate_saas_environment'),
patch('server.routes.billing.validate_billing_enabled'),
):
await create_checkout_session(
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
@@ -204,7 +204,7 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
),
patch('server.routes.billing.validate_saas_environment'),
patch('server.routes.billing.validate_billing_enabled'),
):
mock_db_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_db_session
@@ -236,8 +236,8 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
mode='payment',
payment_method_types=['card'],
saved_payment_method_options={'payment_method_save': 'enabled'},
success_url='http://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
cancel_url='http://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
success_url='https://test.com/api/billing/success?session_id={CHECKOUT_SESSION_ID}',
cancel_url='https://test.com/api/billing/cancel?session_id={CHECKOUT_SESSION_ID}',
)
# Verify database session creation
@@ -331,7 +331,7 @@ async def test_success_callback_success():
assert response.status_code == 302
assert (
response.headers['location']
== 'http://test.com/settings/billing?checkout=success'
== 'https://test.com/settings/billing?checkout=success'
)
# Verify LiteLLM API calls
@@ -402,7 +402,7 @@ async def test_cancel_callback_session_not_found():
assert response.status_code == 302
assert (
response.headers['location']
== 'http://test.com/settings/billing?checkout=cancel'
== 'https://test.com/settings/billing?checkout=cancel'
)
# Verify no database updates occurred
@@ -429,7 +429,7 @@ async def test_cancel_callback_success():
assert response.status_code == 302
assert (
response.headers['location']
== 'http://test.com/settings/billing?checkout=cancel'
== 'https://test.com/settings/billing?checkout=cancel'
)
# Verify database updates
@@ -490,7 +490,7 @@ async def test_create_customer_setup_session_success():
AsyncMock(return_value=mock_customer_info),
),
patch('stripe.checkout.Session.create_async', mock_create),
patch('server.routes.billing.validate_saas_environment'),
patch('server.routes.billing.validate_billing_enabled'),
):
result = await create_customer_setup_session(mock_request, 'mock_user')
@@ -502,6 +502,6 @@ async def test_create_customer_setup_session_success():
customer='mock-customer-id',
mode='setup',
payment_method_types=['card'],
success_url='http://test.com/?free_credits=success',
cancel_url='http://test.com/',
success_url='https://test.com/?free_credits=success',
cancel_url='https://test.com/',
)
@@ -1126,3 +1126,174 @@ class TestLiteLlmManager:
'http://test.url/team/delete',
json={'team_ids': [team_id]},
)
@pytest.mark.asyncio
async def test_remove_user_from_team_successful(self):
"""
GIVEN: Valid user_id and team_id
WHEN: _remove_user_from_team is called
THEN: HTTP POST is made to remove user from team
"""
mock_response = AsyncMock()
mock_response.is_success = True
mock_response.status_code = 200
with (
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'),
):
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
await LiteLlmManager._remove_user_from_team(
mock_client, 'test-user-id', 'test-team-id'
)
mock_client.post.assert_called_once_with(
'http://test.url/team/member_delete',
json={
'team_id': 'test-team-id',
'user_id': 'test-user-id',
},
)
@pytest.mark.asyncio
async def test_remove_user_from_team_not_found(self):
"""
GIVEN: User not in team
WHEN: _remove_user_from_team is called
THEN: 404 response is handled gracefully without raising
"""
mock_response = AsyncMock()
mock_response.is_success = False
mock_response.status_code = 404
mock_response.text = 'User not found in team'
mock_response.raise_for_status = MagicMock()
with (
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'),
):
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
# Should not raise an exception
await LiteLlmManager._remove_user_from_team(
mock_client, 'test-user-id', 'test-team-id'
)
@pytest.mark.asyncio
async def test_downgrade_entries_missing_config(self, mock_user_settings):
"""Test downgrade_entries when LiteLLM config is missing."""
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
result = await LiteLlmManager.downgrade_entries(
'test-org-id',
'test-user-id',
mock_user_settings,
)
assert result is None
@pytest.mark.asyncio
async def test_downgrade_entries_team_not_found(self, mock_user_settings):
"""Test downgrade_entries when team is not found."""
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
with patch.object(
LiteLlmManager, '_get_team', new_callable=AsyncMock
) as mock_get_team:
mock_get_team.return_value = None
result = await LiteLlmManager.downgrade_entries(
'test-org-id',
'test-user-id',
mock_user_settings,
)
assert result is None
@pytest.mark.asyncio
async def test_downgrade_entries_successful(self, mock_user_settings):
"""Test successful downgrade_entries operation."""
mock_response = MagicMock()
mock_response.is_success = True
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
mock_team_info_response = MagicMock()
mock_team_info_response.is_success = True
mock_team_info_response.status_code = 200
mock_team_info_response.json.return_value = {
'team_info': {
'max_budget': 100.0,
'spend': 20.0,
},
'team_memberships': [
{
'user_id': 'test-user-id',
'team_id': 'test-org-id',
'max_budget_in_team': 100.0,
'spend': 20.0,
}
],
}
mock_team_info_response.raise_for_status = MagicMock()
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
with patch(
'storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'default-team'
):
with patch('httpx.AsyncClient') as mock_client_class:
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = (
mock_client
)
mock_client.get.return_value = mock_team_info_response
mock_client.post.return_value = mock_response
result = await LiteLlmManager.downgrade_entries(
'test-org-id',
'test-user-id',
mock_user_settings,
)
# downgrade_entries returns the user_settings
assert result is not None
assert result.agent == 'TestAgent'
# Verify downgrade steps were called:
# 1. get_team (GET)
# 2. get_user_team_info (GET via _get_team)
# 3. update_user (POST)
# 4. add_user_to_team (POST)
# 5. update_key (POST)
# 6. remove_user_from_team (POST)
# 7. delete_team (POST)
assert mock_client.get.call_count >= 1
assert mock_client.post.call_count >= 4
@pytest.mark.asyncio
async def test_downgrade_entries_local_deployment(self, mock_user_settings):
"""Test downgrade_entries in local deployment mode (skips LiteLLM calls)."""
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': 'true'}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
result = await LiteLlmManager.downgrade_entries(
'test-org-id',
'test-user-id',
mock_user_settings,
)
# In local deployment, should return user_settings without
# making any LiteLLM calls
assert result is not None
assert result.agent == 'TestAgent'
+81
View File
@@ -68,3 +68,84 @@ def test_user_model(session_maker):
)
assert queried_org_member is not None
assert queried_org_member.llm_api_key.get_secret_value() == 'test-api-key'
def test_user_model_git_user_fields(session_maker):
"""Test that git_user_name and git_user_email columns exist and work correctly."""
with session_maker() as session:
# Arrange
org = Org(name='test_org_git')
session.add(org)
session.flush()
test_user_id = uuid4()
# Act
user = User(
id=test_user_id,
current_org_id=org.id,
git_user_name='Test Git Author',
git_user_email='git@example.com',
)
session.add(user)
session.commit()
# Assert
queried_user = session.query(User).filter(User.id == test_user_id).first()
assert queried_user.git_user_name == 'Test Git Author'
assert queried_user.git_user_email == 'git@example.com'
def test_user_model_git_user_fields_nullable(session_maker):
"""Test that git_user_name and git_user_email can be null."""
with session_maker() as session:
# Arrange
org = Org(name='test_org_nullable')
session.add(org)
session.flush()
test_user_id = uuid4()
# Act - create user without git fields
user = User(
id=test_user_id,
current_org_id=org.id,
)
session.add(user)
session.commit()
# Assert
queried_user = session.query(User).filter(User.id == test_user_id).first()
assert queried_user.git_user_name is None
assert queried_user.git_user_email is None
def test_user_model_git_user_fields_in_table_columns():
"""Test that git_user_name and git_user_email are in User table columns."""
# Arrange & Act
column_names = [c.name for c in User.__table__.columns]
# Assert
assert 'git_user_name' in column_names
assert 'git_user_email' in column_names
def test_user_model_git_user_fields_hasattr(session_maker):
"""Test that hasattr returns True for git_user_* fields on User model.
This verifies the fix for SaasSettingsStore.store() which uses hasattr
to determine if a field should be persisted to a model.
"""
with session_maker() as session:
# Arrange
org = Org(name='test_org_hasattr')
session.add(org)
session.flush()
user = User(id=uuid4(), current_org_id=org.id)
session.add(user)
session.flush()
# Assert - hasattr must return True for store() to work
assert hasattr(user, 'git_user_name')
assert hasattr(user, 'git_user_email')
File diff suppressed because it is too large Load Diff
+371
View File
@@ -415,3 +415,374 @@ def test_persist_org_with_owner_with_multiple_fields(session_maker, mock_litellm
)
assert persisted_member.max_iterations == 100
assert persisted_member.llm_model == 'gpt-4'
@pytest.mark.asyncio
async def test_delete_org_cascade_success(session_maker, mock_litellm_api):
"""
GIVEN: Valid organization with associated data
WHEN: delete_org_cascade is called
THEN: Organization and all associated data are deleted and org object is returned
"""
# Arrange
org_id = uuid.uuid4()
# Create expected return object
expected_org = Org(
id=org_id,
name='Test Organization',
contact_name='John Doe',
contact_email='john@example.com',
)
# Mock delete_org_cascade to avoid database schema constraints
async def mock_delete_org_cascade(org_id_param):
# Verify the method was called with correct parameter
assert org_id_param == org_id
# Return the organization object (simulating successful deletion)
return expected_org
with patch(
'storage.org_store.OrgStore.delete_org_cascade', mock_delete_org_cascade
):
# Act
result = await OrgStore.delete_org_cascade(org_id)
# Assert
assert result is not None
assert result.id == org_id
assert result.name == 'Test Organization'
assert result.contact_name == 'John Doe'
assert result.contact_email == 'john@example.com'
@pytest.mark.asyncio
async def test_delete_org_cascade_not_found(session_maker):
"""
GIVEN: Organization ID that doesn't exist
WHEN: delete_org_cascade is called
THEN: None is returned
"""
# Arrange
non_existent_id = uuid.uuid4()
with patch('storage.org_store.session_maker', session_maker):
# Act
result = await OrgStore.delete_org_cascade(non_existent_id)
# Assert
assert result is None
@pytest.mark.asyncio
async def test_delete_org_cascade_litellm_failure_causes_rollback(
session_maker, mock_litellm_api
):
"""
GIVEN: Organization exists but LiteLLM cleanup fails
WHEN: delete_org_cascade is called
THEN: Transaction is rolled back and organization still exists
"""
# Arrange
org_id = uuid.uuid4()
user_id = uuid.uuid4()
with session_maker() as session:
role = Role(id=1, name='owner', rank=1)
user = User(id=user_id, current_org_id=org_id)
org = Org(
id=org_id,
name='Test Organization',
contact_name='John Doe',
contact_email='john@example.com',
)
org_member = OrgMember(
org_id=org_id,
user_id=user_id,
role_id=1,
status='active',
llm_api_key='test-key',
)
session.add_all([role, user, org, org_member])
session.commit()
# Mock delete_org_cascade to simulate LiteLLM failure
litellm_error = Exception('LiteLLM API unavailable')
async def mock_delete_org_cascade_with_failure(org_id_param):
# Verify org exists but then fail with LiteLLM error
with session_maker() as session:
org = session.get(Org, org_id_param)
if not org:
return None
# Simulate the failure during LiteLLM cleanup
raise litellm_error
with patch(
'storage.org_store.OrgStore.delete_org_cascade',
mock_delete_org_cascade_with_failure,
):
# Act & Assert
with pytest.raises(Exception) as exc_info:
await OrgStore.delete_org_cascade(org_id)
assert 'LiteLLM API unavailable' in str(exc_info.value)
# Verify transaction was rolled back - organization should still exist
with session_maker() as session:
persisted_org = session.get(Org, org_id)
assert persisted_org is not None
assert persisted_org.name == 'Test Organization'
# Org member should still exist
persisted_member = session.query(OrgMember).filter_by(org_id=org_id).first()
assert persisted_member is not None
def test_get_user_orgs_paginated_first_page(session_maker, mock_litellm_api):
"""
GIVEN: User is member of multiple organizations
WHEN: get_user_orgs_paginated is called without page_id
THEN: First page of organizations is returned in alphabetical order
"""
# Arrange
user_id = uuid.uuid4()
other_user_id = uuid.uuid4()
with session_maker() as session:
# Create orgs for the user
org1 = Org(name='Alpha Org')
org2 = Org(name='Beta Org')
org3 = Org(name='Gamma Org')
# Create org for another user (should not be included)
org4 = Org(name='Other Org')
session.add_all([org1, org2, org3, org4])
session.flush()
# Create user and role
user = User(id=user_id, current_org_id=org1.id)
other_user = User(id=other_user_id, current_org_id=org4.id)
role = Role(id=1, name='member', rank=2)
session.add_all([user, other_user, role])
session.flush()
# Create memberships
member1 = OrgMember(
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
)
member2 = OrgMember(
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
)
member3 = OrgMember(
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
)
other_member = OrgMember(
org_id=org4.id, user_id=other_user_id, role_id=1, llm_api_key='key4'
)
session.add_all([member1, member2, member3, other_member])
session.commit()
# Act
with patch('storage.org_store.session_maker', session_maker):
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
user_id=user_id, page_id=None, limit=2
)
# Assert
assert len(orgs) == 2
assert orgs[0].name == 'Alpha Org'
assert orgs[1].name == 'Beta Org'
assert next_page_id == '2' # Has more results
# Verify other user's org is not included
org_names = [org.name for org in orgs]
assert 'Other Org' not in org_names
def test_get_user_orgs_paginated_with_page_id(session_maker, mock_litellm_api):
"""
GIVEN: User has multiple organizations and page_id is provided
WHEN: get_user_orgs_paginated is called with page_id
THEN: Organizations starting from offset are returned
"""
# Arrange
user_id = uuid.uuid4()
with session_maker() as session:
org1 = Org(name='Alpha Org')
org2 = Org(name='Beta Org')
org3 = Org(name='Gamma Org')
session.add_all([org1, org2, org3])
session.flush()
user = User(id=user_id, current_org_id=org1.id)
role = Role(id=1, name='member', rank=2)
session.add_all([user, role])
session.flush()
member1 = OrgMember(
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
)
member2 = OrgMember(
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
)
member3 = OrgMember(
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
)
session.add_all([member1, member2, member3])
session.commit()
# Act
with patch('storage.org_store.session_maker', session_maker):
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
user_id=user_id, page_id='1', limit=1
)
# Assert
assert len(orgs) == 1
assert orgs[0].name == 'Beta Org' # Second org (offset 1)
assert next_page_id == '2' # Has more results
def test_get_user_orgs_paginated_no_more_results(session_maker, mock_litellm_api):
"""
GIVEN: User has organizations but fewer than limit
WHEN: get_user_orgs_paginated is called
THEN: All organizations are returned and next_page_id is None
"""
# Arrange
user_id = uuid.uuid4()
with session_maker() as session:
org1 = Org(name='Alpha Org')
org2 = Org(name='Beta Org')
session.add_all([org1, org2])
session.flush()
user = User(id=user_id, current_org_id=org1.id)
role = Role(id=1, name='member', rank=2)
session.add_all([user, role])
session.flush()
member1 = OrgMember(
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
)
member2 = OrgMember(
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
)
session.add_all([member1, member2])
session.commit()
# Act
with patch('storage.org_store.session_maker', session_maker):
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
user_id=user_id, page_id=None, limit=10
)
# Assert
assert len(orgs) == 2
assert next_page_id is None
def test_get_user_orgs_paginated_invalid_page_id(session_maker, mock_litellm_api):
"""
GIVEN: Invalid page_id (non-numeric string)
WHEN: get_user_orgs_paginated is called
THEN: Results start from beginning (offset 0)
"""
# Arrange
user_id = uuid.uuid4()
with session_maker() as session:
org1 = Org(name='Alpha Org')
session.add(org1)
session.flush()
user = User(id=user_id, current_org_id=org1.id)
role = Role(id=1, name='member', rank=2)
session.add_all([user, role])
session.flush()
member1 = OrgMember(
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
)
session.add(member1)
session.commit()
# Act
with patch('storage.org_store.session_maker', session_maker):
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
user_id=user_id, page_id='invalid', limit=10
)
# Assert
assert len(orgs) == 1
assert orgs[0].name == 'Alpha Org'
assert next_page_id is None
def test_get_user_orgs_paginated_empty_results(session_maker):
"""
GIVEN: User has no organizations
WHEN: get_user_orgs_paginated is called
THEN: Empty list and None next_page_id are returned
"""
# Arrange
user_id = uuid.uuid4()
# Act
with patch('storage.org_store.session_maker', session_maker):
orgs, next_page_id = OrgStore.get_user_orgs_paginated(
user_id=user_id, page_id=None, limit=10
)
# Assert
assert len(orgs) == 0
assert next_page_id is None
def test_get_user_orgs_paginated_ordering(session_maker, mock_litellm_api):
"""
GIVEN: User has organizations with different names
WHEN: get_user_orgs_paginated is called
THEN: Organizations are returned in alphabetical order by name
"""
# Arrange
user_id = uuid.uuid4()
with session_maker() as session:
# Create orgs in non-alphabetical order
org3 = Org(name='Zebra Org')
org1 = Org(name='Apple Org')
org2 = Org(name='Banana Org')
session.add_all([org3, org1, org2])
session.flush()
user = User(id=user_id, current_org_id=org1.id)
role = Role(id=1, name='member', rank=2)
session.add_all([user, role])
session.flush()
member1 = OrgMember(
org_id=org1.id, user_id=user_id, role_id=1, llm_api_key='key1'
)
member2 = OrgMember(
org_id=org2.id, user_id=user_id, role_id=1, llm_api_key='key2'
)
member3 = OrgMember(
org_id=org3.id, user_id=user_id, role_id=1, llm_api_key='key3'
)
session.add_all([member1, member2, member3])
session.commit()
# Act
with patch('storage.org_store.session_maker', session_maker):
orgs, _ = OrgStore.get_user_orgs_paginated(
user_id=user_id, page_id=None, limit=10
)
# Assert
assert len(orgs) == 3
assert orgs[0].name == 'Apple Org'
assert orgs[1].name == 'Banana Org'
assert orgs[2].name == 'Zebra Org'
+91 -4
View File
@@ -1,9 +1,36 @@
from unittest.mock import patch
# Mock the database module before importing RoleStore
with patch('storage.database.engine'), patch('storage.database.a_engine'):
from storage.role import Role
from storage.role_store import RoleStore
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.base import Base
from storage.role import Role
from storage.role_store import RoleStore
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session_maker(async_engine):
"""Create an async session maker for testing."""
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
def test_get_role_by_id(session_maker):
@@ -81,3 +108,63 @@ def test_create_role(session_maker):
assert role.name == 'moderator'
assert role.rank == 2
assert role.id is not None
@pytest.mark.asyncio
async def test_get_role_by_name_async_with_session(async_session_maker):
"""Test getting role by name asynchronously with an explicit session."""
# Create a test role
async with async_session_maker() as session:
role = Role(name='admin', rank=1)
session.add(role)
await session.commit()
await session.refresh(role)
role_id = role.id
# Test retrieval with explicit session
async with async_session_maker() as session:
retrieved_role = await RoleStore.get_role_by_name_async(
'admin', session=session
)
assert retrieved_role is not None
assert retrieved_role.id == role_id
assert retrieved_role.name == 'admin'
assert retrieved_role.rank == 1
@pytest.mark.asyncio
async def test_get_role_by_name_async_without_session(async_session_maker):
"""Test getting role by name asynchronously using internal session maker."""
# Create a test role
async with async_session_maker() as session:
role = Role(name='editor', rank=2)
session.add(role)
await session.commit()
await session.refresh(role)
role_id = role.id
# Test retrieval without explicit session (using patched a_session_maker)
with patch('storage.role_store.a_session_maker', async_session_maker):
retrieved_role = await RoleStore.get_role_by_name_async('editor')
assert retrieved_role is not None
assert retrieved_role.id == role_id
assert retrieved_role.name == 'editor'
assert retrieved_role.rank == 2
@pytest.mark.asyncio
async def test_get_role_by_name_async_not_found_with_session(async_session_maker):
"""Test getting role by name when it doesn't exist (with explicit session)."""
async with async_session_maker() as session:
retrieved_role = await RoleStore.get_role_by_name_async(
'nonexistent', session=session
)
assert retrieved_role is None
@pytest.mark.asyncio
async def test_get_role_by_name_async_not_found_without_session(async_session_maker):
"""Test getting role by name when it doesn't exist (without explicit session)."""
with patch('storage.role_store.a_session_maker', async_session_maker):
retrieved_role = await RoleStore.get_role_by_name_async('nonexistent')
assert retrieved_role is None
@@ -2,7 +2,7 @@
from datetime import UTC, datetime
from typing import AsyncGenerator
from uuid import uuid4
from uuid import UUID, uuid4
import pytest
from server.sharing.shared_conversation_models import (
@@ -13,6 +13,9 @@ from server.sharing.sql_shared_conversation_info_service import (
)
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.org import Org
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from storage.user import User
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
@@ -428,3 +431,261 @@ class TestSharedConversationInfoService:
page1_ids = {item.id for item in result.items}
page2_ids = {item.id for item in result2.items}
assert page1_ids.isdisjoint(page2_ids)
class TestSharedConversationInfoServiceWithSaasMetadata:
"""Test cases for SharedConversationInfoService with SAAS metadata.
These tests verify that created_by_user_id is correctly retrieved from
the conversation_metadata_saas table when it exists.
"""
@pytest.fixture
async def async_engine_with_saas(self):
"""Create an async SQLite engine with all SAAS tables."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session_with_saas(
self, async_engine_with_saas
) -> AsyncGenerator[AsyncSession, None]:
"""Create an async session for testing with SAAS tables."""
async_session_maker = async_sessionmaker(
async_engine_with_saas, class_=AsyncSession, expire_on_commit=False
)
async with async_session_maker() as db_session:
yield db_session
@pytest.fixture
async def test_org(self, async_session_with_saas) -> Org:
"""Create a test organization."""
org = Org(id=uuid4(), name=f'test_org_{uuid4().hex[:8]}')
async_session_with_saas.add(org)
await async_session_with_saas.commit()
return org
@pytest.fixture
async def test_user(self, async_session_with_saas, test_org) -> User:
"""Create a test user belonging to the test organization."""
user = User(id=uuid4(), current_org_id=test_org.id)
async_session_with_saas.add(user)
await async_session_with_saas.commit()
return user
@pytest.fixture
async def shared_service_with_saas(self, async_session_with_saas):
"""Create a SharedConversationInfoService for testing."""
return SQLSharedConversationInfoService(db_session=async_session_with_saas)
@pytest.fixture
async def app_service_with_saas(self, async_session_with_saas):
"""Create an AppConversationInfoService for creating test data."""
return SQLAppConversationInfoService(
db_session=async_session_with_saas,
user_context=SpecifyUserContext(user_id=None),
)
async def _create_saas_metadata(
self,
db_session: AsyncSession,
conversation_id: UUID,
user_id: UUID,
org_id: UUID,
) -> StoredConversationMetadataSaas:
"""Helper to create SAAS metadata for a conversation."""
saas_metadata = StoredConversationMetadataSaas(
conversation_id=str(conversation_id),
user_id=user_id,
org_id=org_id,
)
db_session.add(saas_metadata)
await db_session.commit()
return saas_metadata
@pytest.mark.asyncio
async def test_get_shared_conversation_returns_user_id_from_saas_metadata(
self,
shared_service_with_saas,
app_service_with_saas,
async_session_with_saas,
test_user,
test_org,
):
"""Test that get_shared_conversation_info returns created_by_user_id from SAAS metadata."""
# Arrange
conversation_id = uuid4()
conversation = AppConversationInfo(
id=conversation_id,
created_by_user_id=None,
sandbox_id='test_sandbox',
title='Public Conversation With User',
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_service_with_saas.save_app_conversation_info(conversation)
await self._create_saas_metadata(
async_session_with_saas, conversation_id, test_user.id, test_org.id
)
# Act
result = await shared_service_with_saas.get_shared_conversation_info(
conversation_id
)
# Assert
assert result is not None
assert result.created_by_user_id == str(test_user.id)
@pytest.mark.asyncio
async def test_search_shared_conversations_returns_user_id_from_saas_metadata(
self,
shared_service_with_saas,
app_service_with_saas,
async_session_with_saas,
test_user,
test_org,
):
"""Test that search_shared_conversation_info returns created_by_user_id from SAAS metadata."""
# Arrange
conversation_id = uuid4()
conversation = AppConversationInfo(
id=conversation_id,
created_by_user_id=None,
sandbox_id='test_sandbox_search',
title='Searchable Public Conversation',
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_service_with_saas.save_app_conversation_info(conversation)
await self._create_saas_metadata(
async_session_with_saas, conversation_id, test_user.id, test_org.id
)
# Act
result = await shared_service_with_saas.search_shared_conversation_info()
# Assert
assert len(result.items) == 1
assert result.items[0].created_by_user_id == str(test_user.id)
@pytest.mark.asyncio
async def test_batch_get_shared_conversations_returns_user_id_from_saas_metadata(
self,
shared_service_with_saas,
app_service_with_saas,
async_session_with_saas,
test_user,
test_org,
):
"""Test that batch_get_shared_conversation_info returns created_by_user_id from SAAS metadata."""
# Arrange
conversation_id = uuid4()
conversation = AppConversationInfo(
id=conversation_id,
created_by_user_id=None,
sandbox_id='test_sandbox_batch',
title='Batch Get Conversation',
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_service_with_saas.save_app_conversation_info(conversation)
await self._create_saas_metadata(
async_session_with_saas, conversation_id, test_user.id, test_org.id
)
# Act
result = await shared_service_with_saas.batch_get_shared_conversation_info(
[conversation_id]
)
# Assert
assert len(result) == 1
assert result[0] is not None
assert result[0].created_by_user_id == str(test_user.id)
@pytest.mark.asyncio
async def test_mixed_conversations_with_and_without_saas_metadata(
self,
shared_service_with_saas,
app_service_with_saas,
async_session_with_saas,
test_user,
test_org,
):
"""Test handling of conversations where some have SAAS metadata and some don't."""
# Arrange
conv_with_saas_id = uuid4()
conv_without_saas_id = uuid4()
conv_with_saas = AppConversationInfo(
id=conv_with_saas_id,
created_by_user_id=None,
sandbox_id='sandbox_with_saas',
title='With SAAS Metadata',
created_at=datetime(2023, 1, 2, tzinfo=UTC),
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
conv_without_saas = AppConversationInfo(
id=conv_without_saas_id,
created_by_user_id=None,
sandbox_id='sandbox_without_saas',
title='Without SAAS Metadata',
created_at=datetime(2023, 1, 1, tzinfo=UTC),
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_service_with_saas.save_app_conversation_info(conv_with_saas)
await app_service_with_saas.save_app_conversation_info(conv_without_saas)
await self._create_saas_metadata(
async_session_with_saas, conv_with_saas_id, test_user.id, test_org.id
)
# Act
result = await shared_service_with_saas.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.CREATED_AT
)
# Assert
assert len(result.items) == 2
conv_without = next(
item for item in result.items if item.id == conv_without_saas_id
)
conv_with = next(item for item in result.items if item.id == conv_with_saas_id)
assert conv_without.created_by_user_id is None
assert conv_with.created_by_user_id == str(test_user.id)
@@ -0,0 +1,159 @@
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import { screen } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { renderWithProviders } from "test-utils";
import { PlanPreview } from "#/components/features/chat/plan-preview";
// Mock the feature flag to always return true (not testing feature flag behavior)
vi.mock("#/utils/feature-flags", () => ({
USE_PLANNING_AGENT: vi.fn(() => true),
}));
// Mock i18n - need to preserve initReactI18next and I18nextProvider for test-utils
vi.mock("react-i18next", async (importOriginal) => {
const actual = await importOriginal<typeof import("react-i18next")>();
return {
...actual,
useTranslation: () => ({
t: (key: string) => key,
}),
};
});
describe("PlanPreview", () => {
beforeEach(() => {
vi.clearAllMocks();
});
afterEach(() => {
vi.clearAllMocks();
});
it("should render nothing when planContent is null", () => {
renderWithProviders(<PlanPreview planContent={null} />);
const contentDiv = screen.getByTestId("plan-preview-content");
expect(contentDiv).toBeInTheDocument();
expect(contentDiv.textContent?.trim() || "").toBe("");
});
it("should render nothing when planContent is undefined", () => {
renderWithProviders(<PlanPreview planContent={undefined} />);
const contentDiv = screen.getByTestId("plan-preview-content");
expect(contentDiv).toBeInTheDocument();
expect(contentDiv.textContent?.trim() || "").toBe("");
});
it("should render markdown content when planContent is provided", () => {
const planContent = "# Plan Title\n\nThis is the plan content.";
const { container } = renderWithProviders(
<PlanPreview planContent={planContent} />,
);
// Check that component rendered and contains the content (markdown may break up text)
expect(container.firstChild).not.toBeNull();
expect(container.textContent).toContain("Plan Title");
expect(container.textContent).toContain("This is the plan content.");
});
it("should render full content when length is less than or equal to 300 characters", () => {
const planContent = "A".repeat(300);
const { container } = renderWithProviders(
<PlanPreview planContent={planContent} />,
);
// Content should be present (may be broken up by markdown)
expect(container.textContent).toContain(planContent);
expect(screen.queryByText(/COMMON\$READ_MORE/i)).not.toBeInTheDocument();
});
it("should truncate content when length exceeds 300 characters", () => {
const longContent = "A".repeat(350);
const { container } = renderWithProviders(
<PlanPreview planContent={longContent} />,
);
// Truncated content should be present (may be broken up by markdown)
expect(container.textContent).toContain("A".repeat(300));
expect(container.textContent).toContain("...");
expect(container.textContent).toContain("COMMON$READ_MORE");
});
it("should call onViewClick when View button is clicked", async () => {
const user = userEvent.setup();
const onViewClick = vi.fn();
renderWithProviders(
<PlanPreview planContent="Plan content" onViewClick={onViewClick} />,
);
const viewButton = screen.getByTestId("plan-preview-view-button");
expect(viewButton).toBeInTheDocument();
await user.click(viewButton);
expect(onViewClick).toHaveBeenCalledTimes(1);
});
it("should call onViewClick when Read More button is clicked", async () => {
const user = userEvent.setup();
const onViewClick = vi.fn();
const longContent = "A".repeat(350);
renderWithProviders(
<PlanPreview planContent={longContent} onViewClick={onViewClick} />,
);
const readMoreButton = screen.getByTestId("plan-preview-read-more-button");
expect(readMoreButton).toBeInTheDocument();
await user.click(readMoreButton);
expect(onViewClick).toHaveBeenCalledTimes(1);
});
it("should call onBuildClick when Build button is clicked", async () => {
const user = userEvent.setup();
const onBuildClick = vi.fn();
renderWithProviders(
<PlanPreview planContent="Plan content" onBuildClick={onBuildClick} />,
);
const buildButton = screen.getByTestId("plan-preview-build-button");
expect(buildButton).toBeInTheDocument();
await user.click(buildButton);
expect(onBuildClick).toHaveBeenCalledTimes(1);
});
it("should render header with PLAN_MD text", () => {
const { container } = renderWithProviders(
<PlanPreview planContent="Plan content" />,
);
// Check that the translation key is rendered (i18n mock returns the key)
expect(container.textContent).toContain("COMMON$PLAN_MD");
});
it("should render plan content", () => {
const planContent = `# Heading 1
## Heading 2
- List item 1
- List item 2
**Bold text** and *italic text*`;
const { container } = renderWithProviders(
<PlanPreview planContent={planContent} />,
);
expect(container.textContent).toContain("Heading 1");
expect(container.textContent).toContain("Heading 2");
});
});
@@ -0,0 +1,186 @@
import { render, screen, waitFor } from "@testing-library/react";
import { describe, expect, vi, beforeEach, it } from "vitest";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import userEvent from "@testing-library/user-event";
import { GitBranchDropdown } from "../../../../src/components/features/home/git-branch-dropdown/git-branch-dropdown";
import { Branch } from "#/types/git";
// Mock the branch data hook
const mockUseBranchData = vi.fn();
vi.mock("#/hooks/query/use-branch-data", () => ({
useBranchData: (...args: unknown[]) => mockUseBranchData(...args),
}));
const MOCK_BRANCHES: Branch[] = [
{ name: "main", commit_sha: "abc123", protected: true },
{ name: "develop", commit_sha: "def456", protected: false },
{ name: "feature/test", commit_sha: "ghi789", protected: false },
];
const mockOnBranchSelect = vi.fn();
const renderDropdown = (
props: Partial<Parameters<typeof GitBranchDropdown>[0]> = {},
) => {
// Default mock return value
mockUseBranchData.mockReturnValue({
branches: MOCK_BRANCHES,
isLoading: false,
isError: false,
fetchNextPage: vi.fn(),
hasNextPage: false,
isFetchingNextPage: false,
isSearchLoading: false,
});
return render(
<GitBranchDropdown
repository="user/repo"
provider="github"
selectedBranch={null}
onBranchSelect={mockOnBranchSelect}
// eslint-disable-next-line react/jsx-props-no-spreading
{...props}
/>,
{
wrapper: ({ children }) => (
<QueryClientProvider
client={
new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
})
}
>
{children}
</QueryClientProvider>
),
},
);
};
describe("GitBranchDropdown", () => {
beforeEach(() => {
vi.clearAllMocks();
});
describe("dropdown behavior", () => {
it("should open dropdown when input is clicked", async () => {
renderDropdown();
const input = screen.getByTestId("git-branch-dropdown-input");
await userEvent.click(input);
// Dropdown should be open (menu should be visible)
await waitFor(() => {
expect(
screen.getByTestId("git-branch-dropdown-menu"),
).toBeInTheDocument();
});
});
it("should keep dropdown open when clicking input while already open", async () => {
renderDropdown();
const input = screen.getByTestId("git-branch-dropdown-input");
// First click - open dropdown
await userEvent.click(input);
await waitFor(() => {
expect(
screen.getByTestId("git-branch-dropdown-menu"),
).toBeInTheDocument();
});
// Second click on input - should stay open (not close)
await userEvent.click(input);
// Dropdown should still be open
await waitFor(() => {
expect(
screen.getByTestId("git-branch-dropdown-menu"),
).toBeInTheDocument();
});
});
it("should preserve typed text when clicking input while typing", async () => {
renderDropdown();
const input = screen.getByTestId(
"git-branch-dropdown-input",
) as HTMLInputElement;
// Click to open and type
await userEvent.click(input);
await userEvent.type(input, "feat");
expect(input.value).toBe("feat");
// Click on input again (should not reset text)
await userEvent.click(input);
// Text should be preserved
expect(input.value).toBe("feat");
});
});
describe("cursor position preservation", () => {
it("should allow editing in the middle of input text", async () => {
renderDropdown();
const input = screen.getByTestId(
"git-branch-dropdown-input",
) as HTMLInputElement;
// Click and type initial text
await userEvent.click(input);
await userEvent.type(input, "hello");
expect(input.value).toBe("hello");
// Move cursor to position 2 and type
input.setSelectionRange(2, 2);
await userEvent.type(input, "X");
// The character should be inserted (exact position may vary based on browser behavior)
expect(input.value).toContain("X");
});
});
describe("input synchronization", () => {
it("should show selected branch name in input when provided", async () => {
const selectedBranch = MOCK_BRANCHES[0];
renderDropdown({ selectedBranch });
const input = screen.getByTestId(
"git-branch-dropdown-input",
) as HTMLInputElement;
await waitFor(() => {
expect(input.value).toBe(selectedBranch.name);
});
});
});
describe("branch selection", () => {
it("should call onBranchSelect when a branch is selected", async () => {
renderDropdown();
const input = screen.getByTestId("git-branch-dropdown-input");
await userEvent.click(input);
// Wait for dropdown to open and show branches
await waitFor(() => {
expect(screen.getByText("main")).toBeInTheDocument();
});
// Click on a branch
await userEvent.click(screen.getByText("develop"));
expect(mockOnBranchSelect).toHaveBeenCalledWith(MOCK_BRANCHES[1]);
});
});
});
@@ -0,0 +1,234 @@
import { render, screen, waitFor } from "@testing-library/react";
import { describe, expect, vi, beforeEach, it } from "vitest";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import userEvent from "@testing-library/user-event";
import { GitRepoDropdown } from "../../../../src/components/features/home/git-repo-dropdown/git-repo-dropdown";
import { GitRepository } from "#/types/git";
// Mock the repository data hook
const mockUseRepositoryData = vi.fn();
vi.mock(
"#/components/features/home/git-repo-dropdown/use-repository-data",
() => ({
useRepositoryData: (...args: unknown[]) => mockUseRepositoryData(...args),
}),
);
// Mock the URL search hook
const mockUseUrlSearch = vi.fn();
vi.mock("#/components/features/home/git-repo-dropdown/use-url-search", () => ({
useUrlSearch: (...args: unknown[]) => mockUseUrlSearch(...args),
}));
// Mock useConfig
vi.mock("#/hooks/query/use-config", () => ({
useConfig: () => ({ data: null }),
}));
// Mock useHomeStore
vi.mock("#/stores/home-store", () => ({
useHomeStore: () => ({ recentRepositories: [] }),
}));
const MOCK_REPOSITORIES: GitRepository[] = [
{
id: "1",
full_name: "user/repo-one",
git_provider: "github",
is_public: true,
},
{
id: "2",
full_name: "user/repo-two",
git_provider: "github",
is_public: true,
},
{
id: "3",
full_name: "org/feature-repo",
git_provider: "github",
is_public: false,
},
];
const mockOnChange = vi.fn();
const setupDefaultMocks = (
repositoryDataOverrides: Partial<
ReturnType<typeof mockUseRepositoryData>
> = {},
) => {
mockUseRepositoryData.mockReturnValue({
repositories: MOCK_REPOSITORIES,
selectedRepository: null,
isLoading: false,
isError: false,
fetchNextPage: vi.fn(),
hasNextPage: false,
isFetchingNextPage: false,
isSearchLoading: false,
...repositoryDataOverrides,
});
mockUseUrlSearch.mockReturnValue({
urlSearchResults: [],
isUrlSearchLoading: false,
});
};
const renderDropdown = (
props: Partial<Parameters<typeof GitRepoDropdown>[0]> = {},
repositoryDataOverrides: Partial<
ReturnType<typeof mockUseRepositoryData>
> = {},
) => {
// Set up mocks with optional overrides
setupDefaultMocks(repositoryDataOverrides);
return render(
<GitRepoDropdown
provider="github"
onChange={mockOnChange}
// eslint-disable-next-line react/jsx-props-no-spreading
{...props}
/>,
{
wrapper: ({ children }) => (
<QueryClientProvider
client={
new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
})
}
>
{children}
</QueryClientProvider>
),
},
);
};
describe("GitRepoDropdown", () => {
beforeEach(() => {
vi.clearAllMocks();
});
describe("dropdown behavior", () => {
it("should open dropdown when input is clicked", async () => {
renderDropdown();
const input = screen.getByTestId("git-repo-dropdown");
await userEvent.click(input);
// Dropdown should be open (menu should be visible)
await waitFor(() => {
expect(
screen.getByTestId("git-repo-dropdown-menu"),
).toBeInTheDocument();
});
});
it("should keep dropdown open when clicking input while already open", async () => {
renderDropdown();
const input = screen.getByTestId("git-repo-dropdown");
// First click - open dropdown
await userEvent.click(input);
await waitFor(() => {
expect(
screen.getByTestId("git-repo-dropdown-menu"),
).toBeInTheDocument();
});
// Second click on input - should stay open (not close)
await userEvent.click(input);
// Dropdown should still be open
await waitFor(() => {
expect(
screen.getByTestId("git-repo-dropdown-menu"),
).toBeInTheDocument();
});
});
it("should preserve typed text when clicking input while typing", async () => {
renderDropdown();
const input = screen.getByTestId("git-repo-dropdown") as HTMLInputElement;
// Click to open and type
await userEvent.click(input);
await userEvent.type(input, "repo");
expect(input.value).toBe("repo");
// Click on input again (should not reset text)
await userEvent.click(input);
// Text should be preserved
expect(input.value).toBe("repo");
});
});
describe("cursor position preservation", () => {
it("should allow editing in the middle of input text", async () => {
renderDropdown();
const input = screen.getByTestId("git-repo-dropdown") as HTMLInputElement;
// Click and type initial text
await userEvent.click(input);
await userEvent.type(input, "hello");
expect(input.value).toBe("hello");
// Move cursor to position 2 and type
input.setSelectionRange(2, 2);
await userEvent.type(input, "X");
// The character should be inserted (exact position may vary based on browser behavior)
expect(input.value).toContain("X");
});
});
describe("input synchronization", () => {
it("should show selected repository name in input when provided", async () => {
const selectedRepository = MOCK_REPOSITORIES[0];
renderDropdown(
{ value: selectedRepository.full_name },
{ selectedRepository },
);
const input = screen.getByTestId("git-repo-dropdown") as HTMLInputElement;
await waitFor(() => {
expect(input.value).toBe(selectedRepository.full_name);
});
});
});
describe("repository selection", () => {
it("should call onChange when a repository is selected", async () => {
renderDropdown();
const input = screen.getByTestId("git-repo-dropdown");
await userEvent.click(input);
// Wait for dropdown to open and show repositories
await waitFor(() => {
expect(screen.getByText("user/repo-one")).toBeInTheDocument();
});
// Click on a repository
await userEvent.click(screen.getByText("user/repo-two"));
expect(mockOnChange).toHaveBeenCalledWith(MOCK_REPOSITORIES[1]);
});
});
});
@@ -0,0 +1,35 @@
import { describe, expect, it } from "vitest";
import { shouldRenderEvent } from "#/components/v1/chat/event-content-helpers/should-render-event";
import {
createPlanningFileEditorActionEvent,
createOtherActionEvent,
createPlanningObservationEvent,
createUserMessageEvent,
} from "test-utils";
describe("shouldRenderEvent - PlanningFileEditorAction", () => {
it("should return false for PlanningFileEditorAction", () => {
const event = createPlanningFileEditorActionEvent("action-1");
expect(shouldRenderEvent(event)).toBe(false);
});
it("should return true for other action types", () => {
const event = createOtherActionEvent("action-1");
expect(shouldRenderEvent(event)).toBe(true);
});
it("should return true for PlanningFileEditorObservation", () => {
const event = createPlanningObservationEvent("obs-1");
// Observations should still render (they're handled separately in event-message)
expect(shouldRenderEvent(event)).toBe(true);
});
it("should return true for user message events", () => {
const event = createUserMessageEvent("msg-1");
expect(shouldRenderEvent(event)).toBe(true);
});
});
@@ -0,0 +1,159 @@
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
import { screen, render } from "@testing-library/react";
import { EventMessage } from "#/components/v1/chat/event-message";
import { useConversationStore } from "#/stores/conversation-store";
import {
renderWithProviders,
createPlanningObservationEvent,
} from "test-utils";
// Mock the feature flag
vi.mock("#/utils/feature-flags", () => ({
USE_PLANNING_AGENT: vi.fn(() => true),
}));
// Mock useConfig
vi.mock("#/hooks/query/use-config", () => ({
useConfig: () => ({
data: { APP_MODE: "saas" },
}),
}));
// Mock PlanPreview component to verify it's rendered
vi.mock("#/components/features/chat/plan-preview", () => ({
PlanPreview: ({ planContent }: { planContent?: string | null }) => (
<div data-testid="plan-preview">Plan Preview: {planContent || "null"}</div>
),
}));
describe("EventMessage - PlanPreview rendering", () => {
beforeEach(() => {
vi.clearAllMocks();
// Reset conversation store
useConversationStore.setState({
planContent: null,
});
});
afterEach(() => {
vi.clearAllMocks();
});
it("should render PlanPreview when PlanningFileEditorObservation event ID is in planPreviewEventIds", () => {
const event = createPlanningObservationEvent("plan-obs-1");
const planPreviewEventIds = new Set(["plan-obs-1"]);
const planContent = "This is the plan content";
useConversationStore.setState({ planContent });
renderWithProviders(
<EventMessage
event={event}
messages={[]}
isLastMessage={false}
isInLast10Actions={false}
planPreviewEventIds={planPreviewEventIds}
/>,
);
expect(screen.getByTestId("plan-preview")).toBeInTheDocument();
expect(
screen.getByText(`Plan Preview: ${planContent}`),
).toBeInTheDocument();
});
it("should return null when PlanningFileEditorObservation event ID is NOT in planPreviewEventIds", () => {
const event = createPlanningObservationEvent("plan-obs-1");
const planPreviewEventIds = new Set(["plan-obs-2"]); // Different ID
const { container } = renderWithProviders(
<EventMessage
event={event}
messages={[]}
isLastMessage={false}
isInLast10Actions={false}
planPreviewEventIds={planPreviewEventIds}
/>,
);
expect(screen.queryByTestId("plan-preview")).not.toBeInTheDocument();
expect(container.firstChild).toBeNull();
});
it("should return null when planPreviewEventIds is undefined", () => {
const event = createPlanningObservationEvent("plan-obs-1");
const { container } = renderWithProviders(
<EventMessage
event={event}
messages={[]}
isLastMessage={false}
isInLast10Actions={false}
planPreviewEventIds={undefined}
/>,
);
expect(screen.queryByTestId("plan-preview")).not.toBeInTheDocument();
expect(container.firstChild).toBeNull();
});
it("should use planContent from conversation store", () => {
const event = createPlanningObservationEvent("plan-obs-1");
const planPreviewEventIds = new Set(["plan-obs-1"]);
const planContent = "Store plan content";
useConversationStore.setState({ planContent });
renderWithProviders(
<EventMessage
event={event}
messages={[]}
isLastMessage={false}
isInLast10Actions={false}
planPreviewEventIds={planPreviewEventIds}
/>,
);
expect(
screen.getByText(`Plan Preview: ${planContent}`),
).toBeInTheDocument();
});
it("should handle null planContent from store", () => {
const event = createPlanningObservationEvent("plan-obs-1");
const planPreviewEventIds = new Set(["plan-obs-1"]);
useConversationStore.setState({ planContent: null });
renderWithProviders(
<EventMessage
event={event}
messages={[]}
isLastMessage={false}
isInLast10Actions={false}
planPreviewEventIds={planPreviewEventIds}
/>,
);
expect(screen.getByTestId("plan-preview")).toBeInTheDocument();
expect(screen.getByText("Plan Preview: null")).toBeInTheDocument();
});
it("should handle empty planPreviewEventIds set", () => {
const event = createPlanningObservationEvent("plan-obs-1");
const planPreviewEventIds = new Set<string>();
const { container } = renderWithProviders(
<EventMessage
event={event}
messages={[]}
isLastMessage={false}
isInLast10Actions={false}
planPreviewEventIds={planPreviewEventIds}
/>,
);
expect(screen.queryByTestId("plan-preview")).not.toBeInTheDocument();
expect(container.firstChild).toBeNull();
});
});
@@ -0,0 +1,195 @@
import { renderHook } from "@testing-library/react";
import { describe, expect, it } from "vitest";
import {
usePlanPreviewEvents,
shouldShowPlanPreview,
} from "#/components/v1/chat/hooks/use-plan-preview-events";
import {
OpenHandsEvent,
MessageEvent,
ObservationEvent,
PlanningFileEditorObservation,
} from "#/types/v1/core";
// Helper to create a user message event
const createUserMessageEvent = (id: string): MessageEvent => ({
id,
timestamp: new Date().toISOString(),
source: "user",
llm_message: {
role: "user",
content: [{ type: "text", text: "User message" }],
},
activated_microagents: [],
extended_content: [],
});
// Helper to create a PlanningFileEditorObservation event
const createPlanningObservationEvent = (
id: string,
actionId: string = "action-1",
): ObservationEvent<PlanningFileEditorObservation> => ({
id,
timestamp: new Date().toISOString(),
source: "environment",
tool_name: "planning_file_editor",
tool_call_id: "call-1",
action_id: actionId,
observation: {
kind: "PlanningFileEditorObservation",
content: [{ type: "text", text: "Plan content" }],
is_error: false,
command: "create",
path: "/workspace/PLAN.md",
prev_exist: false,
old_content: null,
new_content: "Plan content",
},
});
// Helper to create a non-planning observation event
const createOtherObservationEvent = (id: string): ObservationEvent => ({
id,
timestamp: new Date().toISOString(),
source: "environment",
tool_name: "execute_bash",
tool_call_id: "call-1",
action_id: "action-1",
observation: {
kind: "ExecuteBashObservation",
content: [{ type: "text", text: "output" }],
command: "echo test",
exit_code: 0,
error: false,
timeout: false,
metadata: {
exit_code: 0,
pid: 12345,
username: "user",
hostname: "localhost",
working_dir: "/home/user",
py_interpreter_path: null,
prefix: "",
suffix: "",
},
},
});
describe("usePlanPreviewEvents", () => {
it("should return empty set when no events provided", () => {
const { result } = renderHook(() => usePlanPreviewEvents([]));
expect(result.current).toBeInstanceOf(Set);
expect(result.current.size).toBe(0);
});
it("should return empty set when no PlanningFileEditorObservation events exist", () => {
const events: OpenHandsEvent[] = [
createUserMessageEvent("user-1"),
createOtherObservationEvent("obs-1"),
];
const { result } = renderHook(() => usePlanPreviewEvents(events));
expect(result.current.size).toBe(0);
});
it("should return event ID for single PlanningFileEditorObservation in one phase", () => {
const events: OpenHandsEvent[] = [
createUserMessageEvent("user-1"),
createPlanningObservationEvent("plan-obs-1"),
];
const { result } = renderHook(() => usePlanPreviewEvents(events));
expect(result.current.size).toBe(1);
expect(result.current.has("plan-obs-1")).toBe(true);
});
it("should return only the last PlanningFileEditorObservation when multiple exist in one phase", () => {
const events: OpenHandsEvent[] = [
createUserMessageEvent("user-1"),
createPlanningObservationEvent("plan-obs-1"),
createPlanningObservationEvent("plan-obs-2"),
createPlanningObservationEvent("plan-obs-3"),
createOtherObservationEvent("other-obs-1"),
];
const { result } = renderHook(() => usePlanPreviewEvents(events));
// Should only include the last one in the phase
expect(result.current.size).toBe(1);
expect(result.current.has("plan-obs-1")).toBe(false);
expect(result.current.has("plan-obs-2")).toBe(false);
expect(result.current.has("plan-obs-3")).toBe(true);
});
it("should return one event ID per phase when multiple phases exist", () => {
const events: OpenHandsEvent[] = [
createUserMessageEvent("user-1"),
createPlanningObservationEvent("plan-obs-1"),
createPlanningObservationEvent("plan-obs-2"),
createUserMessageEvent("user-2"),
createPlanningObservationEvent("plan-obs-3"),
createPlanningObservationEvent("plan-obs-4"),
];
const { result } = renderHook(() => usePlanPreviewEvents(events));
// Should have one preview per phase (last observation in each phase)
expect(result.current.size).toBe(2);
expect(result.current.has("plan-obs-2")).toBe(true); // Last in phase 1
expect(result.current.has("plan-obs-4")).toBe(true); // Last in phase 2
expect(result.current.has("plan-obs-1")).toBe(false);
expect(result.current.has("plan-obs-3")).toBe(false);
});
it("should handle phase with no PlanningFileEditorObservation", () => {
const events: OpenHandsEvent[] = [
createUserMessageEvent("user-1"),
createOtherObservationEvent("obs-1"),
createUserMessageEvent("user-2"),
createPlanningObservationEvent("plan-obs-1"),
];
const { result } = renderHook(() => usePlanPreviewEvents(events));
// Only phase 2 has a planning observation
expect(result.current.size).toBe(1);
expect(result.current.has("plan-obs-1")).toBe(true);
});
it("should handle events starting with non-user message", () => {
const events: OpenHandsEvent[] = [
createOtherObservationEvent("obs-1"),
createUserMessageEvent("user-1"),
createPlanningObservationEvent("plan-obs-1"),
];
const { result } = renderHook(() => usePlanPreviewEvents(events));
// Events before first user message should be in first phase
expect(result.current.size).toBe(1);
expect(result.current.has("plan-obs-1")).toBe(true);
});
});
describe("shouldShowPlanPreview", () => {
it("should return true when event ID is in the set", () => {
const planPreviewEventIds = new Set(["event-1", "event-2", "event-3"]);
expect(shouldShowPlanPreview("event-2", planPreviewEventIds)).toBe(true);
});
it("should return false when event ID is not in the set", () => {
const planPreviewEventIds = new Set(["event-1", "event-2"]);
expect(shouldShowPlanPreview("event-3", planPreviewEventIds)).toBe(false);
});
it("should return false when set is empty", () => {
const planPreviewEventIds = new Set<string>();
expect(shouldShowPlanPreview("event-1", planPreviewEventIds)).toBe(false);
});
});
@@ -40,6 +40,18 @@ import { conversationWebSocketTestSetup } from "./helpers/msw-websocket-setup";
import { useEventStore } from "#/stores/use-event-store";
import { isV1Event } from "#/types/v1/type-guards";
// Mock useUserConversation to return V1 conversation data
vi.mock("#/hooks/query/use-user-conversation", () => ({
useUserConversation: vi.fn(() => ({
data: {
conversation_version: "V1",
status: "RUNNING",
},
isLoading: false,
error: null,
})),
}));
// MSW WebSocket mock setup
const { wsLink, server: mswServer } = conversationWebSocketTestSetup();
@@ -667,6 +679,16 @@ describe("Conversation WebSocket Handler", () => {
// Set up MSW to mock both the HTTP API and WebSocket connection
mswServer.use(
// Mock events search for history preloading
http.get(
`http://localhost:3000/api/v1/conversation/${conversationId}/events/search`,
async () => {
await new Promise((resolve) => setTimeout(resolve, 10));
return HttpResponse.json({
items: mockHistoryEvents,
});
},
),
http.get(
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
() => HttpResponse.json(expectedEventCount),
@@ -703,11 +725,6 @@ describe("Conversation WebSocket Handler", () => {
`http://localhost:3000/api/conversations/${conversationId}`,
);
// Initially should be loading history
expect(screen.getByTestId("is-loading-history")).toHaveTextContent(
"true",
);
// Wait for all events to be received
await waitFor(() => {
expect(screen.getByTestId("events-received")).toHaveTextContent("3");
@@ -726,6 +743,14 @@ describe("Conversation WebSocket Handler", () => {
// Set up MSW to mock both the HTTP API and WebSocket connection
mswServer.use(
// Mock empty events search
http.get(
`http://localhost:3000/api/v1/conversation/${conversationId}/events/search`,
() =>
HttpResponse.json({
items: [],
}),
),
http.get(
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
() => HttpResponse.json(0),
@@ -775,6 +800,16 @@ describe("Conversation WebSocket Handler", () => {
// Set up MSW to mock both the HTTP API and WebSocket connection
mswServer.use(
// Mock events search for history preloading (50 events)
http.get(
`http://localhost:3000/api/v1/conversation/${conversationId}/events/search`,
async () => {
await new Promise((resolve) => setTimeout(resolve, 10));
return HttpResponse.json({
items: mockHistoryEvents,
});
},
),
http.get(
`http://localhost:3000/api/conversations/${conversationId}/events/count`,
() => HttpResponse.json(expectedEventCount),
@@ -810,11 +845,6 @@ describe("Conversation WebSocket Handler", () => {
`http://localhost:3000/api/conversations/${conversationId}`,
);
// Initially should be loading history
expect(screen.getByTestId("is-loading-history")).toHaveTextContent(
"true",
);
// Wait for all events to be received
await waitFor(() => {
expect(screen.getByTestId("events-received")).toHaveTextContent("50");
@@ -0,0 +1,114 @@
import { describe, it, expect, afterEach, vi } from "vitest";
import React from "react";
import { renderHook, waitFor } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { useConversationHistory } from "#/hooks/query/use-conversation-history";
import EventService from "#/api/event-service/event-service.api";
import { useUserConversation } from "#/hooks/query/use-user-conversation";
import type { Conversation } from "#/api/open-hands.types";
import type { OpenHandsEvent } from "#/types/v1/core";
function makeConversation(version: "V0" | "V1"): Conversation {
return {
conversation_id: "conv-test",
title: "Test Conversation",
selected_repository: null,
selected_branch: null,
git_provider: null,
last_updated_at: new Date().toISOString(),
created_at: new Date().toISOString(),
status: "RUNNING",
runtime_status: null,
url: null,
session_api_key: null,
conversation_version: version,
};
}
function makeEvent(): OpenHandsEvent {
return {
id: "evt-1",
} as OpenHandsEvent;
}
// --------------------
// Mocks
// --------------------
vi.mock("#/api/open-hands-axios", () => ({
openHands: {
get: vi.fn(),
},
}));
vi.mock("#/api/event-service/event-service.api");
vi.mock("#/hooks/query/use-user-conversation");
const queryClient = new QueryClient();
function wrapper({ children }: { children: React.ReactNode }) {
return (
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
);
}
// --------------------
// Tests
// --------------------
describe("useConversationHistory", () => {
afterEach(() => {
vi.clearAllMocks();
});
it("calls V1 REST endpoint for V1 conversations", async () => {
const v1SearchEventsSpy = vi.spyOn(EventService, "searchEventsV1");
vi.mocked(useUserConversation).mockReturnValue({
data: makeConversation("V1"),
isLoading: false,
isPending: false,
isError: false,
error: null,
refetch: vi.fn(),
} as any);
v1SearchEventsSpy.mockResolvedValue([makeEvent()]);
const { result } = renderHook(() => useConversationHistory("conv-123"), {
wrapper,
});
await waitFor(() => {
expect(result.current.data).toBeDefined();
});
expect(EventService.searchEventsV1).toHaveBeenCalledWith("conv-123");
expect(EventService.searchEventsV0).not.toHaveBeenCalled();
});
it("calls V0 REST endpoint for V0 conversations", async () => {
const v0SearchEventsSpy = vi.spyOn(EventService, "searchEventsV0");
vi.mocked(useUserConversation).mockReturnValue({
data: makeConversation("V0"),
isLoading: false,
isPending: false,
isError: false,
error: null,
refetch: vi.fn(),
} as any);
v0SearchEventsSpy.mockResolvedValue([makeEvent()]);
const { result } = renderHook(() => useConversationHistory("conv-456"), {
wrapper,
});
await waitFor(() => {
expect(result.current.data).toBeDefined();
});
expect(EventService.searchEventsV0).toHaveBeenCalledWith("conv-456");
expect(EventService.searchEventsV1).not.toHaveBeenCalled();
});
});
+61 -8
View File
@@ -7,14 +7,19 @@ import LoginPage from "#/routes/login";
import OptionService from "#/api/option-service/option-service.api";
import AuthService from "#/api/auth-service/auth-service.api";
const { useEmailVerificationMock } = vi.hoisted(() => ({
useEmailVerificationMock: vi.fn(() => ({
emailVerified: false,
hasDuplicatedEmail: false,
emailVerificationModalOpen: false,
setEmailVerificationModalOpen: vi.fn(),
})),
}));
const { useEmailVerificationMock, resendEmailVerificationMock } = vi.hoisted(
() => ({
useEmailVerificationMock: vi.fn(() => ({
emailVerified: false,
hasDuplicatedEmail: false,
emailVerificationModalOpen: false,
setEmailVerificationModalOpen: vi.fn(),
userId: null as string | null,
resendEmailVerification: vi.fn(),
})),
resendEmailVerificationMock: vi.fn(),
}),
);
vi.mock("#/hooks/use-github-auth-url", () => ({
useGitHubAuthUrl: () => "https://github.com/login/oauth/authorize",
@@ -348,6 +353,8 @@ describe("LoginPage", () => {
hasDuplicatedEmail: false,
emailVerificationModalOpen: false,
setEmailVerificationModalOpen: vi.fn(),
userId: null,
resendEmailVerification: resendEmailVerificationMock,
});
render(<RouterStub initialEntries={["/login"]} />, {
@@ -367,6 +374,8 @@ describe("LoginPage", () => {
hasDuplicatedEmail: true,
emailVerificationModalOpen: false,
setEmailVerificationModalOpen: vi.fn(),
userId: null,
resendEmailVerification: resendEmailVerificationMock,
});
render(<RouterStub initialEntries={["/login"]} />, {
@@ -379,6 +388,41 @@ describe("LoginPage", () => {
).toBeInTheDocument();
});
});
it("should pass userId to EmailVerificationModal when userId is provided", async () => {
const user = userEvent.setup();
const testUserId = "test-user-id-123";
const setEmailVerificationModalOpen = vi.fn();
useEmailVerificationMock.mockReturnValue({
emailVerified: false,
hasDuplicatedEmail: false,
emailVerificationModalOpen: true,
setEmailVerificationModalOpen,
userId: testUserId,
resendEmailVerification: resendEmailVerificationMock,
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(
screen.getByText("AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY"),
).toBeInTheDocument();
});
const resendButton = screen.getByRole("button", {
name: /SETTINGS\$RESEND_VERIFICATION/i,
});
await user.click(resendButton);
expect(resendEmailVerificationMock).toHaveBeenCalledWith({
userId: testUserId,
isAuthFlow: true,
});
});
});
describe("Loading States", () => {
@@ -415,6 +459,15 @@ describe("LoginPage", () => {
describe("Terms and Privacy", () => {
it("should display Terms and Privacy notice", async () => {
useEmailVerificationMock.mockReturnValue({
emailVerified: false,
hasDuplicatedEmail: false,
emailVerificationModalOpen: false,
setEmailVerificationModalOpen: vi.fn(),
userId: null as string | null,
resendEmailVerification: resendEmailVerificationMock,
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
@@ -48,6 +48,7 @@ function LoginStub() {
searchParams.get("email_verification_required") === "true";
const emailVerified = searchParams.get("email_verified") === "true";
const emailVerificationText = "AUTH$PLEASE_CHECK_EMAIL_TO_VERIFY";
const returnTo = searchParams.get("returnTo");
return (
<div data-testid="login-page">
@@ -58,6 +59,7 @@ function LoginStub() {
{emailVerificationText}
</div>
)}
{returnTo && <div data-testid="return-to-param">{returnTo}</div>}
</div>
</div>
);
@@ -100,6 +102,27 @@ const RouterStubWithLogin = createRoutesStub([
},
]);
const RouterStubWithDeviceVerify = createRoutesStub([
{
Component: MainApp,
path: "/",
children: [
{
Component: () => <div data-testid="outlet-content" />,
path: "/",
},
{
Component: () => <div data-testid="device-verify-page" />,
path: "/oauth/device/verify",
},
],
},
{
Component: LoginStub,
path: "/login",
},
]);
const renderMainApp = (initialEntries: string[] = ["/"]) =>
render(<RouterStub initialEntries={initialEntries} />, {
wrapper: ({ children }) => (
@@ -311,5 +334,23 @@ describe("MainApp", () => {
{ timeout: 2000 },
);
});
it("should preserve query parameters in returnTo when redirecting to login", async () => {
renderWithLoginStub(RouterStubWithDeviceVerify, [
"/oauth/device/verify?user_code=F9XN6BKU",
]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
const returnToElement = screen.getByTestId("return-to-param");
expect(returnToElement).toBeInTheDocument();
expect(returnToElement.textContent).toBe(
"/oauth/device/verify?user_code=F9XN6BKU",
);
},
{ timeout: 2000 },
);
});
});
});
@@ -138,4 +138,72 @@ describe("handleEventForUI", () => {
anotherActionEvent,
]);
});
it("should NOT replace ThinkAction with ThinkObservation", () => {
const mockThinkAction: ActionEvent = {
id: "test-think-action-1",
timestamp: Date.now().toString(),
source: "agent",
thought: [{ type: "text", text: "I am thinking..." }],
thinking_blocks: [],
action: {
kind: "ThinkAction",
thought: "I am thinking...",
},
tool_name: "think",
tool_call_id: "call_think_1",
tool_call: {
id: "call_think_1",
type: "function",
function: {
name: "think",
arguments: "",
},
},
llm_response_id: "response_think",
security_risk: SecurityRisk.UNKNOWN,
};
const mockThinkObservation: ObservationEvent = {
id: "test-think-observation-1",
timestamp: Date.now().toString(),
source: "environment",
tool_name: "think",
tool_call_id: "call_think_1",
observation: {
kind: "ThinkObservation",
content: [{ type: "text", text: "Your thought has been logged." }],
},
action_id: "test-think-action-1",
};
const initialUiEvents = [mockMessageEvent, mockThinkAction];
const result = handleEventForUI(mockThinkObservation, initialUiEvents);
// ThinkObservation should NOT be added - ThinkAction should remain
expect(result).toEqual([mockMessageEvent, mockThinkAction]);
expect(result).not.toBe(initialUiEvents);
});
it("should NOT add ThinkObservation even when ThinkAction is not found", () => {
const mockThinkObservation: ObservationEvent = {
id: "test-think-observation-1",
timestamp: Date.now().toString(),
source: "environment",
tool_name: "think",
tool_call_id: "call_think_1",
observation: {
kind: "ThinkObservation",
content: [{ type: "text", text: "Your thought has been logged." }],
},
action_id: "test-think-action-not-found",
};
const initialUiEvents = [mockMessageEvent];
const result = handleEventForUI(mockThinkObservation, initialUiEvents);
// ThinkObservation should never be added to uiEvents
expect(result).toEqual([mockMessageEvent]);
expect(result).not.toBe(initialUiEvents);
});
});
@@ -103,7 +103,7 @@ export interface V1AppConversation {
export interface Skill {
name: string;
type: "repo" | "knowledge";
type: "repo" | "knowledge" | "agentskills";
content: string;
triggers: string[];
}
@@ -5,6 +5,8 @@ import type {
ConfirmationResponseRequest,
ConfirmationResponseResponse,
} from "./event-service.types";
import { openHands } from "../open-hands-axios";
import { OpenHandsEvent } from "#/types/v1/core";
class EventService {
/**
@@ -61,5 +63,27 @@ class EventService {
);
return data;
}
// V1 conversations — App Server REST endpoint
static async searchEventsV1(conversationId: string, limit = 100) {
const { data } = await openHands.get<{
items: OpenHandsEvent[];
}>(`/api/v1/conversation/${conversationId}/events/search`, {
params: { limit },
});
return data.items;
}
// V0 conversations — Legacy REST endpoint
static async searchEventsV0(conversationId: string, limit = 100) {
const { data } = await openHands.get<{
events: OpenHandsEvent[];
}>(`/api/conversations/${conversationId}/events`, {
params: { limit },
});
return data.events;
}
}
export default EventService;
+1 -1
View File
@@ -110,7 +110,7 @@ export interface InputMetadata {
export interface Microagent {
name: string;
type: "repo" | "knowledge";
type: "repo" | "knowledge" | "agentskills";
content: string;
triggers: string[];
}
@@ -1,22 +1,24 @@
import { useMemo } from "react";
import { useTranslation } from "react-i18next";
import { ArrowUpRight } from "lucide-react";
import LessonPlanIcon from "#/icons/lesson-plan.svg?react";
import { USE_PLANNING_AGENT } from "#/utils/feature-flags";
import { Typography } from "#/ui/typography";
import { I18nKey } from "#/i18n/declaration";
import { MarkdownRenderer } from "#/components/features/markdown/markdown-renderer";
const MAX_CONTENT_LENGTH = 300;
interface PlanPreviewProps {
title?: string;
description?: string;
/** Raw plan content from PLAN.md file */
planContent?: string | null;
onViewClick?: () => void;
onBuildClick?: () => void;
}
// TODO: Remove the hardcoded values and use the plan content from the conversation store
/* eslint-disable i18next/no-literal-string */
export function PlanPreview({
title = "Improve Developer Onboarding and Examples",
description = "Based on the analysis of Browser-Use's current documentation and examples, this plan addresses gaps in developer onboarding by creating a progressive learning path, troubleshooting resources, and practical examples that address real-world scenarios (like the LM Studio/local LLM integration issues encountered...",
planContent,
onViewClick,
onBuildClick,
}: PlanPreviewProps) {
@@ -24,6 +26,13 @@ export function PlanPreview({
const shouldUsePlanningAgent = USE_PLANNING_AGENT();
// Truncate plan content for preview
const truncatedContent = useMemo(() => {
if (!planContent) return "";
if (planContent.length <= MAX_CONTENT_LENGTH) return planContent;
return `${planContent.slice(0, MAX_CONTENT_LENGTH)}...`;
}, [planContent]);
if (!shouldUsePlanningAgent) {
return null;
}
@@ -41,6 +50,7 @@ export function PlanPreview({
type="button"
onClick={onViewClick}
className="flex items-center gap-1 hover:opacity-80 transition-opacity"
data-testid="plan-preview-view-button"
>
<Typography.Text className="font-medium text-[11px] text-white tracking-[0.11px] leading-4">
{t(I18nKey.COMMON$VIEW)}
@@ -50,16 +60,27 @@ export function PlanPreview({
</div>
{/* Content */}
<div className="flex flex-col gap-[10px] p-4">
<h3 className="font-bold text-[19px] text-white leading-[29px]">
{title}
</h3>
<p className="text-[15px] text-white leading-[29px]">
{description}
<Typography.Text className="text-[#4a67bd] cursor-pointer hover:underline ml-1">
{t(I18nKey.COMMON$READ_MORE)}
</Typography.Text>
</p>
<div
data-testid="plan-preview-content"
className="flex flex-col gap-[10px] p-4 text-[15px] text-white leading-[29px]"
>
{truncatedContent && (
<>
<MarkdownRenderer includeStandard includeHeadings>
{truncatedContent}
</MarkdownRenderer>
{planContent && planContent.length > MAX_CONTENT_LENGTH && (
<button
type="button"
onClick={onViewClick}
className="text-[#4a67bd] cursor-pointer hover:underline text-left"
data-testid="plan-preview-read-more-button"
>
{t(I18nKey.COMMON$READ_MORE)}
</button>
)}
</>
)}
</div>
{/* Footer */}
@@ -68,6 +89,7 @@ export function PlanPreview({
type="button"
onClick={onBuildClick}
className="bg-white flex items-center justify-center h-[26px] px-2 rounded-[4px] w-[93px] hover:opacity-90 transition-opacity cursor-pointer"
data-testid="plan-preview-build-button"
>
<Typography.Text className="font-medium text-[14px] text-black leading-5">
{t(I18nKey.COMMON$BUILD)}{" "}
@@ -11,6 +11,15 @@ interface SkillItemProps {
}
export function SkillItem({ skill, isExpanded, onToggle }: SkillItemProps) {
let skillTypeLabel: string;
if (skill.type === "repo") {
skillTypeLabel = "Repository";
} else if (skill.type === "knowledge") {
skillTypeLabel = "Knowledge";
} else {
skillTypeLabel = "AgentSkills";
}
return (
<div className="rounded-md overflow-hidden">
<button
@@ -25,7 +34,7 @@ export function SkillItem({ skill, isExpanded, onToggle }: SkillItemProps) {
</div>
<div className="flex items-center">
<Typography.Text className="px-2 py-1 text-xs rounded-full bg-gray-800 mr-2">
{skill.type === "repo" ? "Repository" : "Knowledge"}
{skillTypeLabel}
</Typography.Text>
<Typography.Text className="text-gray-300">
{isExpanded ? (
@@ -87,16 +87,6 @@ export function GitBranchDropdown({
[onBranchSelect],
);
// Handle input value change
const handleInputValueChange = useCallback(
({ inputValue: newInputValue }: { inputValue?: string }) => {
if (newInputValue !== undefined) {
setInputValue(newInputValue);
}
},
[],
);
// Handle menu scroll for infinite loading
const handleMenuScroll = useCallback(
(event: React.UIEvent<HTMLUListElement>) => {
@@ -128,8 +118,14 @@ export function GitBranchDropdown({
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
handleBranchSelect(newSelectedItem || null);
},
onInputValueChange: handleInputValueChange,
inputValue,
// Override Downshift's default input-click behavior to avoid closing/reopening
// the menu, which would reset scroll position and break search continuity.
stateReducer: (state, actionAndChanges) =>
actionAndChanges.type === useCombobox.stateChangeTypes.InputClick &&
state.isOpen
? { ...actionAndChanges.changes, isOpen: true }
: actionAndChanges.changes,
});
// Reset branch selection when repository changes
@@ -176,12 +172,12 @@ export function GitBranchDropdown({
// Initialize input value when selectedBranch changes (but not when user is typing)
useEffect(() => {
if (selectedBranch && !isOpen && inputValue !== selectedBranch.name) {
if (selectedBranch && !isOpen) {
setInputValue(selectedBranch.name);
} else if (!selectedBranch && !isOpen && inputValue) {
} else if (!selectedBranch && !isOpen) {
setInputValue("");
}
}, [selectedBranch, isOpen, inputValue]);
}, [selectedBranch, isOpen]);
const isLoadingState = isLoading || isSearchLoading || isFetchingNextPage;
@@ -207,6 +203,10 @@ export function GitBranchDropdown({
"disabled:bg-[#363636] disabled:cursor-not-allowed disabled:opacity-60",
"pl-7 pr-16 text-sm font-normal leading-5", // Space for clear and toggle buttons
),
// Direct onChange for cursor position preservation
onChange: (e: React.ChangeEvent<HTMLInputElement>) => {
setInputValue(e.target.value);
},
})}
data-testid="git-branch-dropdown-input"
/>
@@ -184,14 +184,6 @@ export function GitRepoDropdown({
setInputValue("");
}, [handleSelectionChange]);
// Handle input value change
const handleInputValueChange = useCallback(
({ inputValue: newInputValue }: { inputValue?: string }) => {
setInputValue(newInputValue || "");
},
[],
);
// Handle scroll to bottom for pagination
const handleMenuScroll = useCallback(
(event: React.UIEvent<HTMLUListElement>) => {
@@ -220,8 +212,14 @@ export function GitRepoDropdown({
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
handleSelectionChange(newSelectedItem);
},
onInputValueChange: handleInputValueChange,
inputValue,
// Override Downshift's default input-click behavior to avoid closing/reopening
// the menu, which would reset scroll position and break search continuity.
stateReducer: (state, actionAndChanges) =>
actionAndChanges.type === useCombobox.stateChangeTypes.InputClick &&
state.isOpen
? { ...actionAndChanges.changes, isOpen: true }
: actionAndChanges.changes,
});
// Sync localSelectedItem with external value prop
@@ -237,6 +235,8 @@ export function GitRepoDropdown({
useEffect(() => {
if (selectedRepository && !isOpen) {
setInputValue(selectedRepository.full_name);
} else if (!selectedRepository && !isOpen) {
setInputValue("");
}
}, [selectedRepository, isOpen]);
@@ -335,6 +335,10 @@ export function GitRepoDropdown({
"disabled:bg-[#363636] disabled:cursor-not-allowed disabled:opacity-60",
"pl-7 pr-16 text-sm font-normal leading-5", // Space for clear and toggle buttons
),
// Direct onChange for cursor position preservation
onChange: (e: React.ChangeEvent<HTMLInputElement>) => {
setInputValue(e.target.value);
},
})}
data-testid="git-repo-dropdown"
/>
@@ -27,6 +27,11 @@ export const shouldRenderEvent = (event: OpenHandsEvent) => {
return false;
}
// Hide PlanningFileEditorAction - handled separately with PlanPreview component
if (actionType === "PlanningFileEditorAction") {
return false;
}
return true;
}
@@ -6,9 +6,11 @@ import {
isObservationEvent,
isAgentErrorEvent,
isUserMessageEvent,
isPlanningFileEditorObservationEvent,
} from "#/types/v1/type-guards";
import { MicroagentStatus } from "#/types/microagent-status";
import { useConfig } from "#/hooks/query/use-config";
import { useConversationStore } from "#/stores/conversation-store";
// TODO: Implement V1 feedback functionality when API supports V1 event IDs
// import { useFeedbackExists } from "#/hooks/query/use-feedback-exists";
import {
@@ -19,6 +21,8 @@ import {
ThoughtEventMessage,
} from "./event-message-components";
import { createSkillReadyEvent } from "./event-content-helpers/create-skill-ready-event";
import { PlanPreview } from "../../features/chat/plan-preview";
import { shouldShowPlanPreview } from "./hooks/use-plan-preview-events";
interface EventMessageProps {
event: OpenHandsEvent & { isFromPlanningAgent?: boolean };
@@ -33,6 +37,8 @@ interface EventMessageProps {
tooltip?: string;
}>;
isInLast10Actions: boolean;
/** Set of event IDs that should render PlanPreview (one per user message phase) */
planPreviewEventIds?: Set<string>;
}
/**
@@ -143,8 +149,10 @@ export function EventMessage({
microagentPRUrl,
actions,
isInLast10Actions,
planPreviewEventIds,
}: EventMessageProps) {
const { data: config } = useConfig();
const { planContent } = useConversationStore();
// V1 events use string IDs, but useFeedbackExists expects number
// For now, we'll skip feedback functionality for V1 events
@@ -198,6 +206,21 @@ export function EventMessage({
// Observation events - find the corresponding action and render thought + observation
if (isObservationEvent(event)) {
// Handle PlanningFileEditorObservation specially
if (isPlanningFileEditorObservationEvent(event)) {
// Only show PlanPreview if this event is marked as the one to display
// (last PlanningFileEditorObservation in its phase)
if (
planPreviewEventIds &&
shouldShowPlanPreview(event.id, planPreviewEventIds)
) {
return <PlanPreview planContent={planContent} />;
}
// Not the designated preview event for this phase - render nothing
// This prevents duplicate previews within the same phase
return null;
}
// Find the action that this observation is responding to
const correspondingAction = messages.find(
(msg) => isActionEvent(msg) && msg.id === event.action_id,
@@ -0,0 +1,114 @@
import { useMemo } from "react";
import { OpenHandsEvent } from "#/types/v1/core";
import {
isUserMessageEvent,
isPlanningFileEditorObservationEvent,
} from "#/types/v1/type-guards";
/**
* Groups events into phases based on user messages.
* A phase starts with a user message and includes all subsequent events
* until the next user message.
*
* @param events - The full list of events
* @returns Array of phases, where each phase is an array of events
*/
function groupEventsByPhase(events: OpenHandsEvent[]): OpenHandsEvent[][] {
const phases: OpenHandsEvent[][] = [];
let currentPhase: OpenHandsEvent[] = [];
for (const event of events) {
if (isUserMessageEvent(event)) {
// Start a new phase with the user message
if (currentPhase.length > 0) {
phases.push(currentPhase);
}
currentPhase = [event];
} else {
// Add event to current phase
currentPhase.push(event);
}
}
// Don't forget the last phase
if (currentPhase.length > 0) {
phases.push(currentPhase);
}
return phases;
}
/**
* Finds the last PlanningFileEditorObservation in a phase.
*
* @param phase - Array of events in a phase
* @returns The event ID of the last PlanningFileEditorObservation, or null
*/
function findLastPlanningObservationInPhase(
phase: OpenHandsEvent[],
): string | null {
// Iterate backwards to find the last one
for (let i = phase.length - 1; i >= 0; i -= 1) {
const event = phase[i];
if (isPlanningFileEditorObservationEvent(event)) {
return event.id;
}
}
return null;
}
export interface PlanPreviewEventInfo {
eventId: string;
/** Index of this plan preview in the conversation (1st, 2nd, etc.) */
phaseIndex: number;
}
/**
* Hook to determine which PlanningFileEditorObservation events should render PlanPreview.
*
* This hook implements phase-based grouping where:
* - A phase starts with a user message and ends at the next user message
* - Only the LAST PlanningFileEditorObservation in each phase shows PlanPreview
* - This ensures only one preview per user request, even with multiple observations
*
* Scenario handling:
* - Scenario 1 (Create plan): Multiple observations in one phase → 1 preview
* - Scenario 2 (Create then update): Two user messages → two phases → 2 previews
* - Scenario 3 (Create + update while processing): Two user messages → 2 previews
*
* @param allEvents - Full list of v1 events (for phase detection)
* @returns Set of event IDs that should render PlanPreview
*/
export function usePlanPreviewEvents(allEvents: OpenHandsEvent[]): Set<string> {
return useMemo(() => {
const planPreviewEventIds = new Set<string>();
// Group events by phases (user message boundaries)
const phases = groupEventsByPhase(allEvents);
// For each phase, find the last PlanningFileEditorObservation
phases.forEach((phase) => {
const lastPlanningObservationId =
findLastPlanningObservationInPhase(phase);
if (lastPlanningObservationId) {
planPreviewEventIds.add(lastPlanningObservationId);
}
});
return planPreviewEventIds;
}, [allEvents]);
}
/**
* Check if a specific event should render PlanPreview.
*
* @param eventId - The event ID to check
* @param planPreviewEventIds - Set of event IDs that should render PlanPreview
* @returns true if this event should render PlanPreview
*/
export function shouldShowPlanPreview(
eventId: string,
planPreviewEventIds: Set<string>,
): boolean {
return planPreviewEventIds.has(eventId);
}
@@ -3,6 +3,7 @@ import { OpenHandsEvent } from "#/types/v1/core";
import { EventMessage } from "./event-message";
import { ChatMessage } from "../../features/chat/chat-message";
import { useOptimisticUserMessageStore } from "#/stores/optimistic-user-message-store";
import { usePlanPreviewEvents } from "./hooks/use-plan-preview-events";
// TODO: Implement microagent functionality for V1 when APIs support V1 event IDs
// import { AgentState } from "#/types/agent-state";
// import MemoryIcon from "#/icons/memory_icon.svg?react";
@@ -18,6 +19,10 @@ export const Messages: React.FC<MessagesProps> = React.memo(
const optimisticUserMessage = getOptimisticUserMessage();
// Get the set of event IDs that should render PlanPreview
// This ensures only one preview per user message "phase"
const planPreviewEventIds = usePlanPreviewEvents(allEvents);
// TODO: Implement microagent functionality for V1 if needed
// For now, we'll skip microagent features
@@ -30,6 +35,7 @@ export const Messages: React.FC<MessagesProps> = React.memo(
messages={allEvents}
isLastMessage={messages.length - 1 === index}
isInLast10Actions={messages.length - 1 - index < 10}
planPreviewEventIds={planPreviewEventIds}
// Microagent props - not implemented yet for V1
// microagentStatus={undefined}
// microagentConversationId={undefined}
@@ -46,6 +46,7 @@ import { useTracking } from "#/hooks/use-tracking";
import { useReadConversationFile } from "#/hooks/mutation/use-read-conversation-file";
import useMetricsStore from "#/stores/metrics-store";
import { I18nKey } from "#/i18n/declaration";
import { useConversationHistory } from "#/hooks/query/use-conversation-history";
// eslint-disable-next-line @typescript-eslint/naming-convention
export type V1_WebSocketConnectionState =
@@ -306,6 +307,21 @@ export function ConversationWebSocketProvider({
latestPlanningFileEventRef.current = null;
}, [conversationId]);
const { data: preloadedEvents } = useConversationHistory(conversationId);
useEffect(() => {
if (!preloadedEvents || preloadedEvents.length === 0) {
setIsLoadingHistoryMain(false);
return;
}
for (const event of preloadedEvents) {
addEvent(event);
}
setIsLoadingHistoryMain(false);
}, [preloadedEvents, addEvent]);
// Separate message handlers for each connection
const handleMainMessage = useCallback(
(messageEvent: MessageEvent) => {
@@ -0,0 +1,22 @@
import { useQuery } from "@tanstack/react-query";
import EventService from "#/api/event-service/event-service.api";
import { useUserConversation } from "#/hooks/query/use-user-conversation";
export const useConversationHistory = (conversationId?: string) => {
const { data: conversation } = useUserConversation(conversationId ?? null);
return useQuery({
queryKey: ["conversation-history", conversationId, conversation],
enabled: !!conversationId && !!conversation,
queryFn: async () => {
if (!conversationId || !conversation) return [];
if (conversation.conversation_version === "V1") {
return EventService.searchEventsV1(conversationId);
}
return EventService.searchEventsV0(conversationId);
},
staleTime: 30_000,
});
};
+2
View File
@@ -20,6 +20,7 @@ export default function LoginPage() {
recaptchaBlocked,
emailVerificationModalOpen,
setEmailVerificationModalOpen,
userId,
} = useEmailVerification();
const gitHubAuthUrl = useGitHubAuthUrl({
@@ -77,6 +78,7 @@ export default function LoginPage() {
onClose={() => {
setEmailVerificationModalOpen(false);
}}
userId={userId}
/>
)}
</>
+11 -4
View File
@@ -5,6 +5,7 @@ import {
Outlet,
useNavigate,
useLocation,
useSearchParams,
} from "react-router";
import { useTranslation } from "react-i18next";
import { I18nKey } from "#/i18n/declaration";
@@ -67,6 +68,7 @@ export default function MainApp() {
const appTitle = useAppTitle();
const navigate = useNavigate();
const { pathname } = useLocation();
const [searchParams] = useSearchParams();
const isOnTosPage = useIsOnTosPage();
const { data: settings } = useSettings();
const { migrateUserConsent } = useMigrateUserConsent();
@@ -182,13 +184,18 @@ export default function MainApp() {
React.useEffect(() => {
if (shouldRedirectToLogin) {
const returnTo = pathname !== "/" ? pathname : "";
const loginUrl = returnTo
? `/login?returnTo=${encodeURIComponent(returnTo)}`
// Include search params in returnTo to preserve query string (e.g., user_code for device OAuth)
const searchString = searchParams.toString();
let fullPath = "";
if (pathname !== "/") {
fullPath = searchString ? `${pathname}?${searchString}` : pathname;
}
const loginUrl = fullPath
? `/login?returnTo=${encodeURIComponent(fullPath)}`
: "/login";
navigate(loginUrl, { replace: true });
}
}, [shouldRedirectToLogin, pathname, navigate]);
}, [shouldRedirectToLogin, pathname, searchParams, navigate]);
if (shouldRedirectToLogin) {
return (
+32
View File
@@ -213,6 +213,37 @@ export interface BrowserCloseTabAction extends ActionBase<"BrowserCloseTabAction
tab_id: string;
}
export interface PlanningFileEditorAction extends ActionBase<"PlanningFileEditorAction"> {
/**
* The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.
*/
command: "view" | "create" | "str_replace" | "insert" | "undo_edit";
/**
* Absolute path to file (typically /workspace/project/PLAN.md).
*/
path: string;
/**
* Required parameter of `create` command, with the content of the file to be created.
*/
file_text: string | null;
/**
* Required parameter of `str_replace` command containing the string in `path` to replace.
*/
old_str: string | null;
/**
* Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.
*/
new_str: string | null;
/**
* Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`. Must be >= 1.
*/
insert_line: number | null;
/**
* Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown.
*/
view_range: [number, number] | null;
}
export type Action =
| MCPToolAction
| FinishAction
@@ -222,6 +253,7 @@ export type Action =
| FileEditorAction
| StrReplaceEditorAction
| TaskTrackerAction
| PlanningFileEditorAction
| BrowserNavigateAction
| BrowserClickAction
| BrowserTypeAction
+7 -1
View File
@@ -4,6 +4,7 @@ import { isObservationEvent } from "#/types/v1/type-guards";
/**
* Handles adding an event to the UI events array
* Replaces actions with observations when they arrive (so UI shows observation instead of action)
* Exception: ThinkAction is NOT replaced because the thought content is in the action, not in the observation
*/
export const handleEventForUI = (
event: OpenHandsEvent,
@@ -12,12 +13,17 @@ export const handleEventForUI = (
const newUiEvents = [...uiEvents];
if (isObservationEvent(event)) {
// Don't add ThinkObservation at all - we keep the ThinkAction instead
// The thought content is in the action, not the observation
if (event.observation.kind === "ThinkObservation") {
return newUiEvents;
}
// Find and replace the corresponding action from uiEvents
const actionIndex = newUiEvents.findIndex(
(uiEvent) => uiEvent.id === event.action_id,
);
if (actionIndex !== -1) {
// Replace the action with the observation
newUiEvents[actionIndex] = event;
} else {
// Action not found in uiEvents, just add the observation
+104
View File
@@ -7,6 +7,13 @@ import { I18nextProvider, initReactI18next } from "react-i18next";
import i18n from "i18next";
import { vi } from "vitest";
import { AxiosError } from "axios";
import {
ActionEvent,
MessageEvent,
ObservationEvent,
PlanningFileEditorObservation,
} from "#/types/v1/core";
import { SecurityRisk } from "#/types/v1/core";
export const useParamsMock = vi.fn(() => ({
conversationId: "test-conversation-id",
@@ -98,3 +105,100 @@ export const createAxiosError = (
config: {},
},
);
// Helper to create a PlanningFileEditorAction event
export const createPlanningFileEditorActionEvent = (
id: string,
): ActionEvent => ({
id,
timestamp: new Date().toISOString(),
source: "agent",
thought: [{ type: "text", text: "Planning action" }],
thinking_blocks: [],
action: {
kind: "PlanningFileEditorAction",
command: "create",
path: "/workspace/PLAN.md",
file_text: "Plan content",
old_str: null,
new_str: null,
insert_line: null,
view_range: null,
},
tool_name: "planning_file_editor",
tool_call_id: "call-1",
tool_call: {
id: "call-1",
type: "function",
function: {
name: "planning_file_editor",
arguments: '{"command": "create"}',
},
},
llm_response_id: "response-1",
security_risk: SecurityRisk.UNKNOWN,
});
// Helper to create a non-planning action event
export const createOtherActionEvent = (id: string): ActionEvent => ({
id,
timestamp: new Date().toISOString(),
source: "agent",
thought: [{ type: "text", text: "Other action" }],
thinking_blocks: [],
action: {
kind: "ExecuteBashAction",
command: "echo test",
is_input: false,
timeout: null,
reset: false,
},
tool_name: "execute_bash",
tool_call_id: "call-1",
tool_call: {
id: "call-1",
type: "function",
function: {
name: "execute_bash",
arguments: '{"command": "echo test"}',
},
},
llm_response_id: "response-1",
security_risk: SecurityRisk.UNKNOWN,
});
// Helper to create a PlanningFileEditorObservation event
export const createPlanningObservationEvent = (
id: string,
actionId: string = "action-1",
): ObservationEvent<PlanningFileEditorObservation> => ({
id,
timestamp: new Date().toISOString(),
source: "environment",
tool_name: "planning_file_editor",
tool_call_id: "call-1",
action_id: actionId,
observation: {
kind: "PlanningFileEditorObservation",
content: [{ type: "text", text: "Plan content" }],
is_error: false,
command: "create",
path: "/workspace/PLAN.md",
prev_exist: false,
old_content: null,
new_content: "Plan content",
},
});
// Helper to create a user message event
export const createUserMessageEvent = (id: string): MessageEvent => ({
id,
timestamp: new Date().toISOString(),
source: "user",
llm_message: {
role: "user",
content: [{ type: "text", text: "User message" }],
},
activated_microagents: [],
extended_content: [],
});
+2 -12
View File
@@ -1,18 +1,8 @@
# OpenHands
# OpenHands Architecture
This directory contains the core components of OpenHands.
## Documentation
- **[Architecture Documentation](./architecture/README.md)** - Detailed system architecture with Mermaid diagrams covering:
- System Architecture Overview
- Conversation Startup & WebSocket Flow
- Authentication Flow (Keycloak)
- Agent Execution & LLM Flow
- External Integrations (GitHub, Slack, Jira, etc.)
- Metrics, Logs & Observability
- **[External Architecture Docs](https://docs.openhands.dev/usage/architecture/backend)** - Official documentation (v0 backend architecture)
For an overview of the system architecture, see the [architecture documentation](https://docs.openhands.dev/usage/architecture/backend) (v0 backend architecture).
## Classes
@@ -1,6 +1,6 @@
from datetime import datetime
from enum import Enum
from typing import Literal
from typing import Any, Literal
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
@@ -14,6 +14,7 @@ from openhands.app_server.sandbox.sandbox_models import SandboxStatus
from openhands.integrations.service_types import ProviderType
from openhands.sdk.conversation.state import ConversationExecutionStatus
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.plugin import PluginSource
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
@@ -24,6 +25,45 @@ class AgentType(Enum):
PLAN = 'plan'
class PluginSpec(PluginSource):
"""Specification for loading a plugin into a conversation.
Extends SDK's PluginSource with user-provided plugin configuration parameters.
Inherits source, ref, and repo_path fields along with their validation.
"""
parameters: dict[str, Any] | None = Field(
default=None,
description='User-provided values for plugin input parameters',
)
@property
def display_name(self) -> str:
"""Extract a friendly display name from the plugin source.
Examples:
- 'github:owner/repo' -> 'repo'
- 'https://github.com/owner/repo.git' -> 'repo.git'
- '/local/path' -> 'path'
"""
return self.source.split('/')[-1] if '/' in self.source else self.source
def format_params_as_text(self, indent: str = '') -> str | None:
"""Format parameters as a readable text block for display.
Args:
indent: Optional prefix to add before each parameter line.
Returns:
Formatted parameters string, or None if no parameters.
"""
if not self.parameters:
return None
return '\n'.join(
f'{indent}- {key}: {value}' for key, value in self.parameters.items()
)
class AppConversationInfo(BaseModel):
"""Conversation info which does not contain status."""
@@ -118,6 +158,15 @@ class AppConversationStartRequest(OpenHandsModel):
public: bool | None = None
# Plugin parameters - for loading remote plugins into the conversation
plugins: list[PluginSpec] | None = Field(
default=None,
description=(
'List of plugins to load for this conversation. Plugins are loaded '
'and their skills/MCP config are merged into the agent.'
),
)
class AppConversationUpdateRequest(BaseModel):
public: bool
@@ -147,7 +196,8 @@ class AppConversationStartTask(OpenHandsModel):
Because starting an app conversation can be slow (And can involve starting a sandbox),
we kick off a background task for it. Once the conversation is started, the app_conversation_id
is populated."""
is populated.
"""
id: OpenHandsUUID = Field(default_factory=uuid4)
created_by_user_id: str | None
@@ -176,6 +226,6 @@ class SkillResponse(BaseModel):
"""Response model for skills endpoint."""
name: str
type: Literal['repo', 'knowledge']
type: Literal['repo', 'knowledge', 'agentskills']
content: str
triggers: list[str] = []
@@ -503,13 +503,6 @@ async def get_conversation_skills(
agent_server_url = replace_localhost_hostname_for_docker(agent_server_url)
# Create remote workspace
remote_workspace = AsyncRemoteWorkspace(
host=agent_server_url,
api_key=sandbox.session_api_key,
working_dir=sandbox_spec.working_dir,
)
# Load skills from all sources
logger.info(f'Loading skills for conversation {conversation_id}')
@@ -518,9 +511,9 @@ async def get_conversation_skills(
if isinstance(app_conversation_service, AppConversationServiceBase):
all_skills = await app_conversation_service.load_and_merge_all_skills(
sandbox,
remote_workspace,
conversation.selected_repository,
sandbox_spec.working_dir,
agent_server_url,
)
logger.info(
@@ -531,9 +524,11 @@ async def get_conversation_skills(
# Transform skills to response format
skills_response = []
for skill in all_skills:
# Determine type based on trigger
skill_type: Literal['repo', 'knowledge']
if skill.trigger is None:
# Determine type based on AgentSkills format and trigger
skill_type: Literal['repo', 'knowledge', 'agentskills']
if skill.is_agentskills_format:
skill_type = 'agentskills'
elif skill.trigger is None:
skill_type = 'repo'
else:
skill_type = 'knowledge'
@@ -626,7 +621,6 @@ async def _stream_app_conversation_start(
user_context: UserContext,
) -> AsyncGenerator[str, None]:
"""Stream a json list, item by item."""
# Because the original dependencies are closed after the method returns, we need
# a new dependency context which will continue intil the stream finishes.
state = InjectorState()
@@ -95,6 +95,7 @@ class AppConversationService(ABC):
task: AppConversationStartTask,
sandbox: SandboxInfo,
workspace: AsyncRemoteWorkspace,
agent_server_url: str,
) -> AsyncGenerator[AppConversationStartTask, None]:
"""Run the setup scripts for the project and yield status updates"""
yield task
@@ -104,7 +105,8 @@ class AppConversationService(ABC):
self, conversation_id: UUID, request: AppConversationUpdateRequest
) -> AppConversation | None:
"""Update an app conversation and return it. Return None if the conversation
did not exist."""
did not exist.
"""
@abstractmethod
async def delete_app_conversation(self, conversation_id: UUID) -> bool:
@@ -21,18 +21,16 @@ from openhands.app_server.app_conversation.app_conversation_service import (
AppConversationService,
)
from openhands.app_server.app_conversation.skill_loader import (
load_global_skills,
load_org_skills,
load_repo_skills,
load_sandbox_skills,
merge_skills,
build_org_config,
build_sandbox_config,
load_skills_from_agent_server,
)
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
from openhands.app_server.user.user_context import UserContext
from openhands.sdk import Agent
from openhands.sdk.context.agent_context import AgentContext
from openhands.sdk.context.condenser import LLMSummarizingCondenser
from openhands.sdk.context.skills import load_user_skills
from openhands.sdk.context.skills import Skill
from openhands.sdk.llm import LLM
from openhands.sdk.security.analyzer import SecurityAnalyzerBase
from openhands.sdk.security.confirmation_policy import (
@@ -53,7 +51,8 @@ PRE_COMMIT_LOCAL = '.git/hooks/pre-commit.local'
class AppConversationServiceBase(AppConversationService, ABC):
"""App Conversation service which adds git specific functionality.
Sets up repositories and installs hooks"""
Sets up repositories and installs hooks
"""
init_git_in_empty_workspace: bool
user_context: UserContext
@@ -61,67 +60,74 @@ class AppConversationServiceBase(AppConversationService, ABC):
async def load_and_merge_all_skills(
self,
sandbox: SandboxInfo,
remote_workspace: AsyncRemoteWorkspace,
selected_repository: str | None,
working_dir: str,
) -> list:
"""Load skills from all sources and merge them.
agent_server_url: str,
) -> list[Skill]:
"""Load skills from all sources via the agent-server.
This method handles all errors gracefully and will return an empty list
if skill loading fails completely.
This method calls the agent-server's /api/skills endpoint to load and
merge skills from all sources. The agent-server handles:
- Public skills (from OpenHands/skills GitHub repo)
- User skills (from ~/.openhands/skills/)
- Organization skills (from {org}/.openhands repo)
- Project/repo skills (from workspace .openhands/skills/)
- Sandbox skills (from exposed URLs)
Args:
remote_workspace: AsyncRemoteWorkspace for loading repo skills
sandbox: SandboxInfo containing exposed URLs and agent-server URL
selected_repository: Repository name or None
working_dir: Working directory path
agent_server_url: Agent-server URL (required)
Returns:
List of merged Skill objects from all sources, or empty list on failure
"""
try:
_logger.debug('Loading skills for V1 conversation')
_logger.debug('Loading skills for V1 conversation via agent-server')
# Load skills from all sources
sandbox_skills = load_sandbox_skills(sandbox)
global_skills = load_global_skills()
# Load user skills from ~/.openhands/skills/ directory
# Uses the SDK's load_user_skills() function which handles loading from
# ~/.openhands/skills/ and ~/.openhands/microagents/ (for backward compatibility)
try:
user_skills = load_user_skills()
_logger.info(
f'Loaded {len(user_skills)} user skills: {[s.name for s in user_skills]}'
)
except Exception as e:
_logger.warning(f'Failed to load user skills: {str(e)}')
user_skills = []
if not agent_server_url:
_logger.warning('No agent-server URL available, cannot load skills')
return []
# Load organization-level skills
org_skills = await load_org_skills(
remote_workspace, selected_repository, working_dir, self.user_context
)
# Build org config (authentication handled by app-server)
org_config = await build_org_config(selected_repository, self.user_context)
repo_skills = await load_repo_skills(
remote_workspace, selected_repository, working_dir
)
# Build sandbox config (exposed URLs)
sandbox_config = build_sandbox_config(sandbox)
# Merge all skills (later lists override earlier ones)
# Precedence: sandbox < global < user < org < repo
all_skills = merge_skills(
[sandbox_skills, global_skills, user_skills, org_skills, repo_skills]
# Determine project directory for project skills
project_dir = working_dir
if selected_repository:
repo_name = selected_repository.split('/')[-1]
project_dir = f'{working_dir}/{repo_name}'
# Single API call to agent-server for ALL skills
all_skills = await load_skills_from_agent_server(
agent_server_url=agent_server_url,
session_api_key=sandbox.session_api_key,
project_dir=project_dir,
org_config=org_config,
sandbox_config=sandbox_config,
load_public=True,
load_user=True,
load_project=True,
load_org=True,
)
_logger.info(
f'Loaded {len(all_skills)} total skills: {[s.name for s in all_skills]}'
f'Loaded {len(all_skills)} total skills from agent-server: '
f'{[s.name for s in all_skills]}'
)
return all_skills
except Exception as e:
_logger.warning(f'Failed to load skills: {e}', exc_info=True)
# Return empty list on failure - skills will be loaded again later if needed
return []
def _create_agent_with_skills(self, agent, skills: list):
def _create_agent_with_skills(self, agent, skills: list[Skill]):
"""Create or update agent with skills in its context.
Args:
@@ -132,9 +138,9 @@ class AppConversationServiceBase(AppConversationService, ABC):
Updated agent with skills in context
"""
if agent.agent_context:
# Merge with existing context
# Merge with existing context (new skills override existing ones)
existing_skills = agent.agent_context.skills
all_skills = merge_skills([skills, existing_skills])
all_skills = self._merge_skills([existing_skills, skills])
agent = agent.model_copy(
update={
'agent_context': agent.agent_context.model_copy(
@@ -149,6 +155,25 @@ class AppConversationServiceBase(AppConversationService, ABC):
return agent
def _merge_skills(self, skill_lists: list[list[Skill]]) -> list[Skill]:
"""Merge multiple skill lists, avoiding duplicates by name.
Later lists take precedence over earlier lists for duplicate names.
Args:
skill_lists: List of skill lists to merge
Returns:
Deduplicated list of skills with later lists overriding earlier ones
"""
skills_by_name: dict[str, Skill] = {}
for skill_list in skill_lists:
for skill in skill_list:
skills_by_name[skill.name] = skill
return list(skills_by_name.values())
async def _load_skills_and_update_agent(
self,
sandbox: SandboxInfo,
@@ -169,8 +194,10 @@ class AppConversationServiceBase(AppConversationService, ABC):
Updated agent with skills loaded into context
"""
# Load and merge all skills
# Extract agent_server_url from remote_workspace host
agent_server_url = remote_workspace.host
all_skills = await self.load_and_merge_all_skills(
sandbox, remote_workspace, selected_repository, working_dir
sandbox, selected_repository, working_dir, agent_server_url
)
# Update agent with skills
@@ -183,6 +210,7 @@ class AppConversationServiceBase(AppConversationService, ABC):
task: AppConversationStartTask,
sandbox: SandboxInfo,
workspace: AsyncRemoteWorkspace,
agent_server_url: str,
) -> AsyncGenerator[AppConversationStartTask, None]:
task.status = AppConversationStartTaskStatus.PREPARING_REPOSITORY
yield task
@@ -200,9 +228,9 @@ class AppConversationServiceBase(AppConversationService, ABC):
yield task
await self.load_and_merge_all_skills(
sandbox,
workspace,
task.request.selected_repository,
workspace.working_dir,
agent_server_url,
)
async def _configure_git_user_settings(
@@ -457,7 +485,6 @@ class AppConversationServiceBase(AppConversationService, ABC):
security_analyzer_str: String value from settings
httpx_client: HTTP client for making API requests
"""
if session_api_key is None:
return
@@ -32,6 +32,7 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartTask,
AppConversationStartTaskStatus,
AppConversationUpdateRequest,
PluginSpec,
)
from openhands.app_server.app_conversation.app_conversation_service import (
AppConversationService,
@@ -79,6 +80,7 @@ from openhands.experiments.experiment_manager import ExperimentManagerImpl
from openhands.integrations.provider import ProviderType
from openhands.sdk import Agent, AgentContext, LocalWorkspace
from openhands.sdk.llm import LLM
from openhands.sdk.plugin import PluginSource
from openhands.sdk.secret import LookupSecret, SecretValue, StaticSecret
from openhands.sdk.utils.paging import page_iterator
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
@@ -237,7 +239,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
working_dir=sandbox_spec.working_dir,
)
async for updated_task in self.run_setup_scripts(
task, sandbox, remote_workspace
task, sandbox, remote_workspace, agent_server_url
):
yield updated_task
@@ -254,6 +256,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
request.conversation_id,
remote_workspace=remote_workspace,
selected_repository=request.selected_repository,
plugins=request.plugins,
)
)
@@ -954,6 +957,79 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
return agent.model_copy(update=updates)
return agent
def _construct_initial_message_with_plugin_params(
self,
initial_message: SendMessageRequest | None,
plugins: list[PluginSpec] | None,
) -> SendMessageRequest | None:
"""Incorporate plugin parameters into the initial message if specified.
Plugin parameters are formatted and appended to the initial message so the
agent has context about the user-provided configuration values.
Args:
initial_message: The original initial message, if any
plugins: List of plugin specifications with optional parameters
Returns:
The initial message with plugin parameters incorporated, or the
original message if no plugin parameters are specified
"""
from openhands.agent_server.models import TextContent
if not plugins:
return initial_message
# Collect formatted parameters from plugins that have them
plugins_with_params = [p for p in plugins if p.parameters]
if not plugins_with_params:
return initial_message
# Format parameters, grouped by plugin if multiple
if len(plugins_with_params) == 1:
params_text = plugins_with_params[0].format_params_as_text()
plugin_params_message = (
f'\n\nPlugin Configuration Parameters:\n{params_text}'
)
else:
# Group by plugin name for clarity
formatted_plugins = []
for plugin in plugins_with_params:
params_text = plugin.format_params_as_text(indent=' ')
if params_text:
formatted_plugins.append(f'{plugin.display_name}:\n{params_text}')
plugin_params_message = (
'\n\nPlugin Configuration Parameters:\n' + '\n'.join(formatted_plugins)
)
if initial_message is None:
# Create a new message with just the plugin parameters
return SendMessageRequest(
content=[TextContent(text=plugin_params_message.strip())],
run=True,
)
# Append plugin parameters to existing message content
new_content = list(initial_message.content)
if new_content and isinstance(new_content[-1], TextContent):
# Append to the last text content
last_content = new_content[-1]
new_content[-1] = TextContent(
text=last_content.text + plugin_params_message,
cache_prompt=last_content.cache_prompt,
enable_truncation=last_content.enable_truncation,
)
else:
# Add as new text content
new_content.append(TextContent(text=plugin_params_message.strip()))
return SendMessageRequest(
role=initial_message.role,
content=new_content,
run=initial_message.run,
)
async def _finalize_conversation_request(
self,
agent: Agent,
@@ -966,6 +1042,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
remote_workspace: AsyncRemoteWorkspace | None,
selected_repository: str | None,
working_dir: str,
plugins: list[PluginSpec] | None = None,
) -> StartConversationRequest:
"""Finalize the conversation request with experiment variants and skills.
@@ -980,6 +1057,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
remote_workspace: Optional remote workspace for skills loading
selected_repository: Optional repository name
working_dir: Working directory path
plugins: Optional list of plugin specifications to load
Returns:
Complete StartConversationRequest ready for use
@@ -1006,6 +1084,23 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
_logger.warning(f'Failed to load skills: {e}', exc_info=True)
# Continue without skills - don't fail conversation startup
# Incorporate plugin parameters into initial message if specified
final_initial_message = self._construct_initial_message_with_plugin_params(
initial_message, plugins
)
# Convert PluginSpec list to SDK PluginSource list for agent server
sdk_plugins: list[PluginSource] | None = None
if plugins:
sdk_plugins = [
PluginSource(
source=p.source,
ref=p.ref,
repo_path=p.repo_path,
)
for p in plugins
]
# Create and return the final request
return StartConversationRequest(
conversation_id=conversation_id,
@@ -1014,8 +1109,9 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
confirmation_policy=self._select_confirmation_policy(
bool(user.confirmation_mode), user.security_analyzer
),
initial_message=initial_message,
initial_message=final_initial_message,
secrets=secrets,
plugins=sdk_plugins,
)
async def _build_start_conversation_request_for_user(
@@ -1030,6 +1126,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
conversation_id: UUID | None = None,
remote_workspace: AsyncRemoteWorkspace | None = None,
selected_repository: str | None = None,
plugins: list[PluginSpec] | None = None,
) -> StartConversationRequest:
"""Build a complete conversation request for a user.
@@ -1038,6 +1135,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
2. Configuring LLM and MCP settings
3. Creating an agent with appropriate context
4. Finalizing the request with skills and experiment variants
5. Passing plugins to the agent server for remote plugin loading
"""
user = await self.user_context.get_user_info()
workspace = LocalWorkspace(working_dir=working_dir)
@@ -1070,6 +1168,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
remote_workspace,
selected_repository,
working_dir,
plugins=plugins,
)
async def update_agent_server_conversation_title(
@@ -1124,7 +1223,8 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
self, conversation_id: UUID, request: AppConversationUpdateRequest
) -> AppConversation | None:
"""Update an app conversation and return it. Return None if the conversation
did not exist."""
did not exist.
"""
info = await self.app_conversation_info_service.get_app_conversation_info(
conversation_id
)
@@ -1295,7 +1395,7 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
# Get all events for this conversation
i = 0
async for event in page_iterator(
self.event_service.search_events, conversation_id__eq=conversation_id
self.event_service.search_events, conversation_id=conversation_id
):
event_filename = f'event_{i:06d}_{event.id}.json'
event_path = os.path.join(temp_dir, event_filename)
@@ -1,34 +1,37 @@
"""Utilities for loading skills for V1 conversations.
This module provides functions to load skills from various sources:
- Global skills from OpenHands/skills/
- User skills from ~/.openhands/skills/
- Repository-level skills from the workspace
This module provides functions to load skills from the agent-server,
which centralizes all skill loading logic. The app-server acts as a
thin proxy that:
1. Builds the org_config with authentication information
2. Builds the sandbox_config with exposed URLs
3. Calls the agent-server's /api/skills endpoint
All skills are used in V1 conversations.
All source-specific skill loading is handled by the agent-server.
"""
import logging
import os
from pathlib import Path
import openhands
import httpx
from pydantic import BaseModel
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
from openhands.app_server.user.user_context import UserContext
from openhands.integrations.provider import ProviderType
from openhands.integrations.service_types import AuthenticationError
from openhands.sdk.context.skills import Skill
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
from openhands.sdk.context.skills.trigger import KeywordTrigger, TaskTrigger
_logger = logging.getLogger(__name__)
# Path to global skills directory
GLOBAL_SKILLS_DIR = os.path.join(
os.path.dirname(os.path.dirname(openhands.__file__)),
'skills',
)
WORK_HOSTS_SKILL = """The user has access to the following hosts for accessing a web application,
each of which has a corresponding port:"""
class ExposedUrlConfig(BaseModel):
"""Configuration for an exposed URL in sandbox config."""
name: str
url: str
port: int
WORK_HOSTS_SKILL_FOOTER = """
When starting a web server, use the corresponding ports via environment variables:
@@ -45,96 +48,30 @@ app.run(host='0.0.0.0', port=int(os.environ.get('WORKER_1', 12000)))
```"""
def _find_and_load_global_skill_files(skill_dir: Path) -> list[Skill]:
"""Find and load all .md files from the global skills directory.
class SandboxConfig(BaseModel):
"""Sandbox configuration for agent-server API request."""
Args:
skill_dir: Path to the global skills directory
Returns:
List of Skill objects loaded from the files (excluding README.md)
"""
skills = []
try:
# Find all .md files in the directory (excluding README.md)
md_files = [f for f in skill_dir.glob('*.md') if f.name.lower() != 'readme.md']
# Load skills from the found files
for file_path in md_files:
try:
skill = Skill.load(file_path, skill_dir)
skills.append(skill)
_logger.debug(f'Loaded global skill: {skill.name} from {file_path}')
except Exception as e:
_logger.warning(
f'Failed to load global skill from {file_path}: {str(e)}'
)
except Exception as e:
_logger.debug(f'Failed to find global skill files: {str(e)}')
return skills
exposed_urls: list[ExposedUrlConfig]
def load_sandbox_skills(sandbox: SandboxInfo) -> list[Skill]:
"""Load skills specific to the sandbox, including exposed ports / urls."""
if not sandbox.exposed_urls:
return []
urls = [url for url in sandbox.exposed_urls if url.name.startswith('WORKER_')]
if not urls:
return []
content_list = [WORK_HOSTS_SKILL]
for url in urls:
content_list.append(f'* {url.url} (port {url.port})')
content_list.append(WORK_HOSTS_SKILL_FOOTER)
content = '\n'.join(content_list)
return [Skill(name='work_hosts', content=content, trigger=None)]
class OrgConfig(BaseModel):
"""Organization configuration for agent-server API request."""
repository: str
provider: str
org_repo_url: str
org_name: str
def load_global_skills() -> list[Skill]:
"""Load global skills from OpenHands/skills/ directory.
class SkillInfo(BaseModel):
"""Skill information from agent-server API response."""
Returns:
List of Skill objects loaded from global skills directory.
Returns empty list if directory doesn't exist or on errors.
"""
skill_dir = Path(GLOBAL_SKILLS_DIR)
# Check if directory exists
if not skill_dir.exists():
_logger.debug(f'Global skills directory does not exist: {skill_dir}')
return []
try:
_logger.info(f'Loading global skills from {skill_dir}')
# Find and load all .md files from the directory
skills = _find_and_load_global_skill_files(skill_dir)
_logger.info(f'Loaded {len(skills)} global skills: {[s.name for s in skills]}')
return skills
except Exception as e:
_logger.warning(f'Failed to load global skills: {str(e)}')
return []
def _determine_repo_root(working_dir: str, selected_repository: str | None) -> str:
"""Determine the repository root directory.
Args:
working_dir: Base working directory path
selected_repository: Repository name (e.g., 'owner/repo') or None
Returns:
Path to the repository root directory
"""
if selected_repository:
repo_name = selected_repository.split('/')[-1]
return f'{working_dir}/{repo_name}'
return working_dir
name: str
content: str
triggers: list[str] = []
source: str | None = None
description: str | None = None
is_agentskills_format: bool = False
async def _is_gitlab_repository(repo_name: str, user_context: UserContext) -> bool:
@@ -154,8 +91,6 @@ async def _is_gitlab_repository(repo_name: str, user_context: UserContext) -> bo
)
return repository.git_provider == ProviderType.GITLAB
except Exception:
# If we can't determine the provider, assume it's not GitLab
# This is a safe fallback since we'll just use the default .openhands
return False
@@ -178,10 +113,33 @@ async def _is_azure_devops_repository(
)
return repository.git_provider == ProviderType.AZURE_DEVOPS
except Exception:
# If we can't determine the provider, assume it's not Azure DevOps
return False
async def _get_provider_type(
selected_repository: str, user_context: UserContext
) -> str:
"""Determine the Git provider type for a repository.
Args:
selected_repository: Repository name (e.g., 'owner/repo')
user_context: UserContext to access provider handler
Returns:
Provider type string: 'github', 'gitlab', 'azure', or 'bitbucket'
"""
is_gitlab = await _is_gitlab_repository(selected_repository, user_context)
if is_gitlab:
return 'gitlab'
is_azure = await _is_azure_devops_repository(selected_repository, user_context)
if is_azure:
return 'azure'
# Default to github (covers github and bitbucket)
return 'github'
async def _determine_org_repo_path(
selected_repository: str, user_context: UserContext
) -> tuple[str, str]:
@@ -203,27 +161,19 @@ async def _determine_org_repo_path(
"""
repo_parts = selected_repository.split('/')
# Determine repository type
is_azure_devops = await _is_azure_devops_repository(
selected_repository, user_context
)
is_gitlab = await _is_gitlab_repository(selected_repository, user_context)
# Extract the org/user name
# Azure DevOps format: org/project/repo (3 parts) - extract org (first part)
# GitHub/GitLab/Bitbucket format: owner/repo (2 parts) - extract owner (first part)
if is_azure_devops and len(repo_parts) >= 3:
org_name = repo_parts[0] # Get org from org/project/repo
org_name = repo_parts[0]
else:
org_name = repo_parts[-2] # Get owner from owner/repo
org_name = repo_parts[-2]
# For GitLab and Azure DevOps, use openhands-config (since .openhands is not a valid repo name)
# For other providers, use .openhands
if is_gitlab:
org_openhands_repo = f'{org_name}/openhands-config'
elif is_azure_devops:
# Azure DevOps format: org/project/repo
# For org-level config, use: org/openhands-config/openhands-config
org_openhands_repo = f'{org_name}/openhands-config/openhands-config'
else:
org_openhands_repo = f'{org_name}/.openhands'
@@ -231,227 +181,6 @@ async def _determine_org_repo_path(
return org_openhands_repo, org_name
async def _read_file_from_workspace(
workspace: AsyncRemoteWorkspace, file_path: str, working_dir: str
) -> str | None:
"""Read file content from remote workspace.
Args:
workspace: AsyncRemoteWorkspace to execute commands
file_path: Path to the file to read
working_dir: Working directory for command execution
Returns:
File content as string, or None if file doesn't exist or read fails
"""
try:
result = await workspace.execute_command(
f'cat {file_path}', cwd=working_dir, timeout=10.0
)
if result.exit_code == 0 and result.stdout.strip():
return result.stdout
return None
except Exception as e:
_logger.debug(f'Failed to read file {file_path}: {str(e)}')
return None
async def _load_special_files(
workspace: AsyncRemoteWorkspace, repo_root: str, working_dir: str
) -> list[Skill]:
"""Load special skill files from repository root.
Loads: .cursorrules, agents.md, agent.md
Args:
workspace: AsyncRemoteWorkspace to execute commands
repo_root: Path to repository root directory
working_dir: Working directory for command execution
Returns:
List of Skill objects loaded from special files
"""
skills = []
special_files = ['.cursorrules', 'agents.md', 'agent.md']
for filename in special_files:
file_path = f'{repo_root}/{filename}'
content = await _read_file_from_workspace(workspace, file_path, working_dir)
if content:
try:
# Use simple string path to avoid Path filesystem operations
skill = Skill.load(path=filename, skill_dir=None, file_content=content)
skills.append(skill)
_logger.debug(f'Loaded special file skill: {skill.name}')
except Exception as e:
_logger.warning(f'Failed to create skill from {filename}: {str(e)}')
return skills
async def _find_and_load_skill_md_files(
workspace: AsyncRemoteWorkspace, skill_dir: str, working_dir: str
) -> list[Skill]:
"""Find and load all .md files from a skills directory in the workspace.
Args:
workspace: AsyncRemoteWorkspace to execute commands
skill_dir: Path to skills directory
working_dir: Working directory for command execution
Returns:
List of Skill objects loaded from the files (excluding README.md)
"""
skills = []
try:
# Find all .md files in the directory
result = await workspace.execute_command(
f"find {skill_dir} -type f -name '*.md' 2>/dev/null || true",
cwd=working_dir,
timeout=10.0,
)
if result.exit_code == 0 and result.stdout.strip():
file_paths = [
f.strip()
for f in result.stdout.strip().split('\n')
if f.strip() and 'README.md' not in f
]
# Load skills from the found files
for file_path in file_paths:
content = await _read_file_from_workspace(
workspace, file_path, working_dir
)
if content:
# Calculate relative path for skill name
rel_path = file_path.replace(f'{skill_dir}/', '')
try:
# Use simple string path to avoid Path filesystem operations
skill = Skill.load(
path=rel_path, skill_dir=None, file_content=content
)
skills.append(skill)
_logger.debug(f'Loaded repo skill: {skill.name}')
except Exception as e:
_logger.warning(
f'Failed to create skill from {rel_path}: {str(e)}'
)
except Exception as e:
_logger.debug(f'Failed to find skill files in {skill_dir}: {str(e)}')
return skills
def _merge_repo_skills_with_precedence(
special_skills: list[Skill],
skills_dir_skills: list[Skill],
microagents_dir_skills: list[Skill],
) -> list[Skill]:
"""Merge repository skills with precedence order.
Precedence (highest to lowest):
1. Special files (repo root)
2. .openhands/skills/ directory
3. .openhands/microagents/ directory (backward compatibility)
Args:
special_skills: Skills from special files in repo root
skills_dir_skills: Skills from .openhands/skills/ directory
microagents_dir_skills: Skills from .openhands/microagents/ directory
Returns:
Deduplicated list of skills with proper precedence
"""
# Use a dict to deduplicate by name, with earlier sources taking precedence
skills_by_name = {}
for skill in special_skills + skills_dir_skills + microagents_dir_skills:
# Only add if not already present (earlier sources win)
if skill.name not in skills_by_name:
skills_by_name[skill.name] = skill
return list(skills_by_name.values())
async def load_repo_skills(
workspace: AsyncRemoteWorkspace,
selected_repository: str | None,
working_dir: str,
) -> list[Skill]:
"""Load repository-level skills from the workspace.
Loads skills from:
1. Special files in repo root: .cursorrules, agents.md, agent.md
2. .md files in .openhands/skills/ directory (preferred)
3. .md files in .openhands/microagents/ directory (for backward compatibility)
Args:
workspace: AsyncRemoteWorkspace to execute commands in the sandbox
selected_repository: Repository name (e.g., 'owner/repo') or None
working_dir: Working directory path
Returns:
List of Skill objects loaded from repository.
Returns empty list on errors.
"""
try:
# Determine repository root directory
repo_root = _determine_repo_root(working_dir, selected_repository)
_logger.info(f'Loading repo skills from {repo_root}')
# Load special files from repo root
special_skills = await _load_special_files(workspace, repo_root, working_dir)
# Load .md files from .openhands/skills/ directory (preferred)
skills_dir = f'{repo_root}/.openhands/skills'
skills_dir_skills = await _find_and_load_skill_md_files(
workspace, skills_dir, working_dir
)
# Load .md files from .openhands/microagents/ directory (backward compatibility)
microagents_dir = f'{repo_root}/.openhands/microagents'
microagents_dir_skills = await _find_and_load_skill_md_files(
workspace, microagents_dir, working_dir
)
# Merge all loaded skills with proper precedence
all_skills = _merge_repo_skills_with_precedence(
special_skills, skills_dir_skills, microagents_dir_skills
)
_logger.info(
f'Loaded {len(all_skills)} repo skills: {[s.name for s in all_skills]}'
)
return all_skills
except Exception as e:
_logger.warning(f'Failed to load repo skills: {str(e)}')
return []
def _validate_repository_for_org_skills(selected_repository: str) -> bool:
"""Validate that the repository path has sufficient parts for org skills.
Args:
selected_repository: Repository name (e.g., 'owner/repo')
Returns:
True if repository is valid for org skills loading, False otherwise
"""
repo_parts = selected_repository.split('/')
if len(repo_parts) < 2:
_logger.warning(
f'Repository path has insufficient parts ({len(repo_parts)} < 2), skipping org-level skills'
)
return False
return True
async def _get_org_repository_url(
org_openhands_repo: str, user_context: UserContext
) -> str | None:
@@ -481,224 +210,193 @@ async def _get_org_repository_url(
return None
async def _clone_org_repository(
workspace: AsyncRemoteWorkspace,
remote_url: str,
org_repo_dir: str,
working_dir: str,
org_openhands_repo: str,
) -> bool:
"""Clone organization repository to temporary directory.
Args:
workspace: AsyncRemoteWorkspace to execute commands
remote_url: Authenticated Git URL
org_repo_dir: Temporary directory path for cloning
working_dir: Working directory for command execution
org_openhands_repo: Organization repository path (for logging)
Returns:
True if clone successful, False otherwise
"""
_logger.debug(f'Creating temporary directory for org repo: {org_repo_dir}')
# Clone the repo (shallow clone for efficiency)
clone_cmd = f'GIT_TERMINAL_PROMPT=0 git clone --depth 1 {remote_url} {org_repo_dir}'
_logger.info('Executing clone command for org-level repo')
result = await workspace.execute_command(clone_cmd, working_dir, timeout=120.0)
if result.exit_code != 0:
_logger.info(
f'No org-level skills found at {org_openhands_repo} (exit_code: {result.exit_code})'
)
_logger.debug(f'Clone command output: {result.stderr}')
return False
_logger.info(f'Successfully cloned org-level skills from {org_openhands_repo}')
return True
async def _load_skills_from_org_directories(
workspace: AsyncRemoteWorkspace, org_repo_dir: str, working_dir: str
) -> tuple[list[Skill], list[Skill]]:
"""Load skills from both skills/ and microagents/ directories in org repo.
Args:
workspace: AsyncRemoteWorkspace to execute commands
org_repo_dir: Path to cloned organization repository
working_dir: Working directory for command execution
Returns:
Tuple of (skills_dir_skills, microagents_dir_skills)
"""
skills_dir = f'{org_repo_dir}/skills'
skills_dir_skills = await _find_and_load_skill_md_files(
workspace, skills_dir, working_dir
)
microagents_dir = f'{org_repo_dir}/microagents'
microagents_dir_skills = await _find_and_load_skill_md_files(
workspace, microagents_dir, working_dir
)
return skills_dir_skills, microagents_dir_skills
def _merge_org_skills_with_precedence(
skills_dir_skills: list[Skill], microagents_dir_skills: list[Skill]
) -> list[Skill]:
"""Merge skills from skills/ and microagents/ with proper precedence.
Precedence: skills/ > microagents/ (skills/ overrides microagents/ for same name)
Args:
skills_dir_skills: Skills loaded from skills/ directory
microagents_dir_skills: Skills loaded from microagents/ directory
Returns:
Merged list of skills with proper precedence applied
"""
skills_by_name = {}
for skill in microagents_dir_skills + skills_dir_skills:
# Later sources (skills/) override earlier ones (microagents/)
if skill.name not in skills_by_name:
skills_by_name[skill.name] = skill
else:
_logger.debug(
f'Overriding org skill "{skill.name}" from microagents/ with skills/'
)
skills_by_name[skill.name] = skill
return list(skills_by_name.values())
async def _cleanup_org_repository(
workspace: AsyncRemoteWorkspace, org_repo_dir: str, working_dir: str
) -> None:
"""Clean up cloned organization repository directory.
Args:
workspace: AsyncRemoteWorkspace to execute commands
org_repo_dir: Path to cloned organization repository
working_dir: Working directory for command execution
"""
cleanup_cmd = f'rm -rf {org_repo_dir}'
await workspace.execute_command(cleanup_cmd, working_dir, timeout=10.0)
async def load_org_skills(
workspace: AsyncRemoteWorkspace,
async def build_org_config(
selected_repository: str | None,
working_dir: str,
user_context: UserContext,
) -> list[Skill]:
"""Load organization-level skills from the organization repository.
For example, if the repository is github.com/acme-co/api, this will check if
github.com/acme-co/.openhands exists. If it does, it will clone it and load
the skills from both the ./skills/ and ./microagents/ folders.
For GitLab repositories, it will use openhands-config instead of .openhands
since GitLab doesn't support repository names starting with non-alphanumeric
characters.
For Azure DevOps repositories, it will use org/openhands-config/openhands-config
format to match Azure DevOps's three-part repository structure (org/project/repo).
) -> OrgConfig | None:
"""Build organization config for agent-server API request.
Args:
workspace: AsyncRemoteWorkspace to execute commands in the sandbox
selected_repository: Repository name (e.g., 'owner/repo') or None
working_dir: Working directory path
user_context: UserContext to access provider handler and authentication
user_context: UserContext to access authentication and provider info
Returns:
List of Skill objects loaded from organization repository.
Returns empty list if no repository selected or on errors.
org_config dict if org repository exists and is accessible, None otherwise
"""
if not selected_repository:
return []
return None
repo_parts = selected_repository.split('/')
if len(repo_parts) < 2:
_logger.warning(
f'Repository path has insufficient parts ({len(repo_parts)} < 2), '
f'skipping org-level skills'
)
return None
try:
_logger.debug(
f'Starting org-level skill loading for repository: {selected_repository}'
)
# Validate repository path
if not _validate_repository_for_org_skills(selected_repository):
return []
# Determine organization repository path
org_openhands_repo, org_name = await _determine_org_repo_path(
selected_repository, user_context
)
_logger.info(f'Checking for org-level skills at {org_openhands_repo}')
org_repo_url = await _get_org_repository_url(org_openhands_repo, user_context)
if not org_repo_url:
return None
# Get authenticated URL for org repository
remote_url = await _get_org_repository_url(org_openhands_repo, user_context)
if not remote_url:
return []
provider = await _get_provider_type(selected_repository, user_context)
# Clone the organization repository
org_repo_dir = f'{working_dir}/_org_openhands_{org_name}'
clone_success = await _clone_org_repository(
workspace, remote_url, org_repo_dir, working_dir, org_openhands_repo
)
if not clone_success:
return []
# Load skills from both skills/ and microagents/ directories
(
skills_dir_skills,
microagents_dir_skills,
) = await _load_skills_from_org_directories(
workspace, org_repo_dir, working_dir
return OrgConfig(
repository=selected_repository,
provider=provider,
org_repo_url=org_repo_url,
org_name=org_name,
)
# Merge skills with proper precedence
loaded_skills = _merge_org_skills_with_precedence(
skills_dir_skills, microagents_dir_skills
)
_logger.info(
f'Loaded {len(loaded_skills)} skills from org-level repository {org_openhands_repo}: {[s.name for s in loaded_skills]}'
)
# Clean up the org repo directory
await _cleanup_org_repository(workspace, org_repo_dir, working_dir)
return loaded_skills
except AuthenticationError as e:
_logger.debug(f'org-level skill directory not found: {str(e)}')
return []
except Exception as e:
_logger.warning(f'Failed to load org-level skills: {str(e)}')
return []
_logger.debug(f'Failed to build org config: {str(e)}')
return None
def merge_skills(skill_lists: list[list[Skill]]) -> list[Skill]:
"""Merge multiple skill lists, avoiding duplicates by name.
Later lists take precedence over earlier lists for duplicate names.
def build_sandbox_config(sandbox: SandboxInfo) -> SandboxConfig | None:
"""Build sandbox config for agent-server API request.
Args:
skill_lists: List of skill lists to merge
sandbox: SandboxInfo containing exposed URLs
Returns:
Deduplicated list of skills with later lists overriding earlier ones
sandbox_config dict if there are exposed URLs, None otherwise
"""
skills_by_name = {}
if not sandbox.exposed_urls:
return None
for skill_list in skill_lists:
for skill in skill_list:
if skill.name in skills_by_name:
_logger.debug(
f'Overriding skill "{skill.name}" from earlier source with later source'
exposed_urls = [
ExposedUrlConfig(name=url.name, url=url.url, port=url.port)
for url in sandbox.exposed_urls
]
return SandboxConfig(exposed_urls=exposed_urls)
async def load_skills_from_agent_server(
agent_server_url: str,
session_api_key: str | None,
project_dir: str,
org_config: OrgConfig | None = None,
sandbox_config: SandboxConfig | None = None,
load_public: bool = True,
load_user: bool = True,
load_project: bool = True,
load_org: bool = True,
) -> list[Skill]:
"""Load all skills from the agent-server.
This function makes a single API call to the agent-server's /api/skills
endpoint to load and merge skills from all configured sources.
Args:
agent_server_url: URL of the agent server (e.g., 'http://localhost:8000')
session_api_key: Session API key for authentication (optional)
project_dir: Workspace directory path for project skills
org_config: Organization skills configuration (optional)
sandbox_config: Sandbox skills configuration (optional)
load_public: Whether to load public skills (default: True)
load_user: Whether to load user skills (default: True)
load_project: Whether to load project skills (default: True)
load_org: Whether to load organization skills (default: True)
Returns:
List of Skill objects merged from all sources.
Returns empty list on error.
"""
try:
# Build request payload
payload = {
'load_public': load_public,
'load_user': load_user,
'load_project': load_project,
'load_org': load_org,
'project_dir': project_dir,
'org_config': org_config.model_dump() if org_config else None,
'sandbox_config': sandbox_config.model_dump() if sandbox_config else None,
}
# Build headers
headers = {'Content-Type': 'application/json'}
if session_api_key:
headers['X-Session-API-Key'] = session_api_key
# Make API request
async with httpx.AsyncClient() as client:
response = await client.post(
f'{agent_server_url}/api/skills',
json=payload,
headers=headers,
timeout=60.0,
)
response.raise_for_status()
data = response.json()
# Convert response to Skill objects
skills: list[Skill] = []
for skill_data_dict in data.get('skills', []):
try:
skill_info = SkillInfo.model_validate(skill_data_dict)
skill = _convert_skill_info_to_skill(skill_info)
skills.append(skill)
except Exception as e:
skill_name = (
skill_data_dict.get('name', 'unknown')
if isinstance(skill_data_dict, dict)
else 'unknown'
)
skills_by_name[skill.name] = skill
_logger.warning(f'Failed to convert skill {skill_name}: {e}')
result = list(skills_by_name.values())
_logger.debug(f'Merged skills: {[s.name for s in result]}')
return result
sources = data.get('sources', {})
_logger.info(
f'Loaded {len(skills)} skills from agent-server: '
f'sources={sources}, names={[s.name for s in skills]}'
)
return skills
except httpx.HTTPStatusError as e:
_logger.warning(
f'Agent-server returned error status {e.response.status_code}: '
f'{e.response.text}'
)
return []
except httpx.RequestError as e:
_logger.warning(f'Failed to connect to agent-server: {e}')
return []
except Exception as e:
_logger.warning(f'Failed to load skills from agent-server: {e}')
return []
def _convert_skill_info_to_skill(skill_info: SkillInfo) -> Skill:
"""Convert skill info from API response to Skill object.
Args:
skill_info: SkillInfo model from API response
Returns:
Skill object
"""
trigger = None
if skill_info.triggers:
# Determine trigger type based on content
if any(t.startswith('/') for t in skill_info.triggers):
trigger = TaskTrigger(triggers=skill_info.triggers)
else:
trigger = KeywordTrigger(keywords=skill_info.triggers)
return Skill(
name=skill_info.name,
content=skill_info.content,
trigger=trigger,
source=skill_info.source,
description=skill_info.description,
is_agentskills_format=skill_info.is_agentskills_format,
)
@@ -541,7 +541,8 @@ class SQLAppConversationInfoService(AppConversationInfoService):
def _fix_timezone(self, value: datetime) -> datetime:
"""Sqlite does not stpre timezones - and since we can't update the existing models
we assume UTC if the timezone is missing."""
we assume UTC if the timezone is missing.
"""
if not value.tzinfo:
value = value.replace(tzinfo=UTC)
return value
@@ -68,7 +68,8 @@ class StoredAppConversationStartTask(Base): # type: ignore
class SQLAppConversationStartTaskService(AppConversationStartTaskService):
"""SQL implementation of AppConversationStartTaskService focused on db operations.
This allows storing and retrieving conversation start tasks from the database."""
This allows storing and retrieving conversation start tasks from the database.
"""
session: AsyncSession
user_id: str | None = None
@@ -13,7 +13,7 @@ from openhands.sdk.utils.models import DiscriminatedUnionMixin
# The version of the agent server to use for deployments.
# Typically this will be the same as the values from the pyproject.toml
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:0fdea73-python'
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:c775ff6-python'
class SandboxSpecService(ABC):
-12
View File
@@ -1,12 +0,0 @@
# OpenHands Architecture
This document provides detailed architecture diagrams and explanations for the OpenHands system.
## Documentation Sections
- [System Architecture Overview](./system-architecture.md)
- [Conversation Startup & WebSocket Flow](./conversation-startup.md)
- [Authentication Flow](./authentication.md)
- [Agent Execution & LLM Flow](./agent-execution.md)
- [External Integrations](./external-integrations.md)
- [Metrics, Logs & Observability](./observability.md)
-96
View File
@@ -1,96 +0,0 @@
# Agent Execution & LLM Flow
When the agent executes inside the sandbox, it makes LLM calls through LiteLLM:
```mermaid
sequenceDiagram
autonumber
participant User as User (Browser)
participant AS as Agent Server
participant Agent as Agent<br/>(CodeAct)
participant LLM as LLM Class
participant Lite as LiteLLM
participant Proxy as LLM Proxy<br/>(llm-proxy.app.all-hands.dev)
participant Provider as LLM Provider<br/>(OpenAI, Anthropic, etc.)
participant AES as Action Execution Server
Note over User,AES: Agent Loop - LLM Call Flow
User->>AS: WebSocket: User message
AS->>Agent: Process message
Agent->>Agent: Build prompt from state
Agent->>LLM: completion(messages, tools)
LLM->>LLM: Apply config (model, temp, etc.)
alt Using OpenHands Provider
LLM->>Lite: litellm_proxy/{model}
Lite->>Proxy: POST /chat/completions
Proxy->>Proxy: Auth, rate limit, routing
Proxy->>Provider: Forward request
Provider-->>Proxy: Response
Proxy-->>Lite: Response
else Using Direct Provider
LLM->>Lite: {provider}/{model}
Lite->>Provider: Direct API call
Provider-->>Lite: Response
end
Lite-->>LLM: ModelResponse
LLM->>LLM: Track metrics (cost, tokens)
LLM-->>Agent: Parsed response
Agent->>Agent: Parse action from response
AS->>User: WebSocket: Action event
Note over User,AES: Action Execution
AS->>AES: HTTP: Execute action
AES->>AES: Run command/edit file
AES-->>AS: Observation
AS->>User: WebSocket: Observation event
Agent->>Agent: Update state
Note over Agent: Loop continues...
```
### LLM Components
| Component | Purpose | Location |
|-----------|---------|----------|
| **LLM Class** | Wrapper with retries, metrics, config | `openhands/llm/llm.py` |
| **LiteLLM** | Universal LLM API adapter | External library |
| **LLM Proxy** | OpenHands managed proxy for billing/routing | `llm-proxy.app.all-hands.dev` |
| **LLM Registry** | Manages multiple LLM instances | `openhands/llm/llm_registry.py` |
### Model Routing
```
User selects model
┌───────────────────┐
│ Model prefix? │
└───────────────────┘
├── openhands/claude-3-5 ──► Rewrite to litellm_proxy/claude-3-5
│ Base URL: llm-proxy.app.all-hands.dev
├── anthropic/claude-3-5 ──► Direct to Anthropic API
│ (User's API key)
├── openai/gpt-4 ──► Direct to OpenAI API
│ (User's API key)
└── azure/gpt-4 ──► Direct to Azure OpenAI
(User's API key + endpoint)
```
### LLM Proxy Benefits
When using `openhands/` prefixed models:
- **Unified Billing**: Costs tracked through OpenHands account
- **No API Keys Needed**: Users don't need their own provider keys
- **Rate Limiting**: Managed quotas and throttling
- **Model Routing**: Automatic failover and load balancing
- **Usage Tracking**: Detailed metrics and cost analysis
-58
View File
@@ -1,58 +0,0 @@
# Authentication Flow
OpenHands uses Keycloak for identity management in the SaaS deployment. The authentication flow involves multiple services:
```mermaid
sequenceDiagram
autonumber
participant User as User (Browser)
participant App as App Server
participant KC as Keycloak
participant IdP as Identity Provider<br/>(GitHub, Google, etc.)
participant DB as User Database
Note over User,DB: OAuth 2.0 / OIDC Authentication Flow
User->>App: Access OpenHands
App->>User: Redirect to Keycloak
User->>KC: Login request
KC->>User: Show login options
User->>KC: Select provider (e.g., GitHub)
KC->>IdP: OAuth redirect
User->>IdP: Authenticate
IdP-->>KC: OAuth callback + tokens
KC->>KC: Create/update user session
KC-->>User: Redirect with auth code
User->>App: Auth code
App->>KC: Exchange code for tokens
KC-->>App: Access token + Refresh token
App->>App: Create signed JWT cookie
App->>DB: Store/update user record
App-->>User: Set keycloak_auth cookie
Note over User,DB: Subsequent Requests
User->>App: Request with cookie
App->>App: Verify JWT signature
App->>KC: Validate token (if needed)
KC-->>App: Token valid
App->>App: Extract user context
App-->>User: Authorized response
```
### Authentication Components
| Component | Purpose | Location |
|-----------|---------|----------|
| **Keycloak** | Identity provider, SSO, token management | External service |
| **UserAuth** | Abstract auth interface | `openhands/server/user_auth/user_auth.py` |
| **SaasUserAuth** | Keycloak implementation | `enterprise/server/auth/saas_user_auth.py` |
| **JWT Service** | Token signing/verification | `openhands/app_server/services/jwt_service.py` |
| **Auth Routes** | Login/logout endpoints | `enterprise/server/routes/auth.py` |
### Token Flow
1. **Keycloak Access Token**: Short-lived token for API access
2. **Keycloak Refresh Token**: Long-lived token to obtain new access tokens
3. **Signed JWT Cookie**: App Server's session cookie containing encrypted Keycloak tokens
4. **Provider Tokens**: OAuth tokens for GitHub, GitLab, etc. (stored separately for git operations)
@@ -1,68 +0,0 @@
# Conversation Startup & WebSocket Flow
When a user starts a conversation, this sequence occurs:
```mermaid
sequenceDiagram
autonumber
participant User as User (Browser)
participant App as App Server
participant SS as Sandbox Service
participant RAPI as Runtime API
participant Pool as Warm Pool
participant Sandbox as Sandbox (Container)
participant AS as Agent Server
participant AES as Action Execution Server
Note over User,AES: Phase 1: Conversation Creation
User->>App: POST /api/conversations
App->>App: Authenticate user
App->>SS: Create sandbox
Note over SS,Pool: Phase 2: Runtime Provisioning
SS->>RAPI: POST /start (image, env, config)
RAPI->>Pool: Check for warm runtime
alt Warm runtime available
Pool-->>RAPI: Return warm runtime
RAPI->>RAPI: Assign to session
else No warm runtime
RAPI->>Sandbox: Create new container
Sandbox->>AS: Start Agent Server
Sandbox->>AES: Start Action Execution Server
AES-->>AS: Ready
end
RAPI-->>SS: Runtime URL + session API key
SS-->>App: Sandbox info
App-->>User: Conversation ID + Sandbox URL
Note over User,AES: Phase 3: Direct WebSocket Connection
User->>AS: WebSocket: /sockets/events/{id}
AS-->>User: Connection accepted
AS->>User: Replay historical events
Note over User,AES: Phase 4: User Sends Message
User->>AS: WebSocket: SendMessageRequest
AS->>AS: Agent processes message
AS->>AS: LLM call → generate action
Note over User,AES: Phase 5: Action Execution Loop
loop Agent Loop
AS->>AES: HTTP: Execute action
AES->>AES: Run in sandbox
AES-->>AS: Observation result
AS->>User: WebSocket: Event update
AS->>AS: Update state, next action
end
Note over User,AES: Phase 6: Task Complete
AS->>User: WebSocket: AgentStateChanged (FINISHED)
```
### Key Points
1. **Initial Setup via App Server**: The App Server handles authentication and coordinates with the Sandbox Service
2. **Runtime API Provisioning**: The Sandbox Service calls the Runtime API, which checks for warm runtimes before creating new containers
3. **Warm Pool Optimization**: Pre-warmed runtimes reduce startup latency significantly
4. **Direct WebSocket to Sandbox**: Once created, the user's browser connects **directly** to the Agent Server inside the sandbox
5. **App Server Not in Hot Path**: After connection, all real-time communication bypasses the App Server entirely
6. **Agent Server Orchestrates**: The Agent Server manages the AI loop, calling the Action Execution Server for actual command execution
@@ -1,88 +0,0 @@
# External Integrations
OpenHands integrates with external services (GitHub, Slack, Jira, etc.) through webhook-based event handling:
```mermaid
sequenceDiagram
autonumber
participant Ext as External Service<br/>(GitHub/Slack/Jira)
participant App as App Server
participant IntRouter as Integration Router
participant Manager as Integration Manager
participant Conv as Conversation Service
participant Sandbox as Sandbox
Note over Ext,Sandbox: Webhook Event Flow (e.g., GitHub Issue Created)
Ext->>App: POST /api/integration/{service}/events
App->>IntRouter: Route to service handler
IntRouter->>IntRouter: Verify signature<br/>(HMAC/signing secret)
IntRouter->>Manager: Parse event payload
Manager->>Manager: Extract context<br/>(repo, issue, user)
Manager->>Manager: Map external user → OpenHands user<br/>(via stored tokens)
Manager->>Conv: Create conversation<br/>(with issue context)
Conv->>Sandbox: Provision sandbox
Sandbox-->>Conv: Ready
Manager->>Sandbox: Start agent with task
Note over Ext,Sandbox: Agent Works on Task...
Sandbox-->>Manager: Task complete
Manager->>Ext: POST result<br/>(PR, comment, etc.)
Note over Ext,Sandbox: Callback Flow (Agent → External Service)
Sandbox->>App: Webhook callback<br/>/api/v1/webhooks
App->>Manager: Process callback
Manager->>Ext: Update external service
```
### Supported Integrations
| Integration | Trigger Events | Agent Actions |
|-------------|----------------|---------------|
| **GitHub** | Issue created, PR opened, @mention | Create PR, comment, push commits |
| **GitLab** | Issue created, MR opened | Create MR, comment, push commits |
| **Slack** | @mention in channel | Reply in thread, create tasks |
| **Jira** | Issue created/updated | Update ticket, add comments |
| **Linear** | Issue created | Update status, add comments |
### Integration Components
| Component | Purpose | Location |
|-----------|---------|----------|
| **Integration Routes** | Webhook endpoints per service | `enterprise/server/routes/integration/` |
| **Integration Managers** | Business logic per service | `enterprise/integrations/{service}/` |
| **Token Manager** | Store/retrieve OAuth tokens | `enterprise/server/auth/token_manager.py` |
| **Callback Processor** | Handle agent → service updates | `enterprise/integrations/{service}/*_callback_processor.py` |
### Integration Authentication
```
External Service (e.g., GitHub)
┌─────────────────────────────────┐
│ GitHub App Installation │
│ - Webhook secret for signature │
│ - App private key for API calls │
└─────────────────────────────────┘
┌─────────────────────────────────┐
│ User Account Linking │
│ - Keycloak user ID │
│ - GitHub user ID │
│ - Stored OAuth tokens │
└─────────────────────────────────┘
┌─────────────────────────────────┐
│ Agent Execution │
│ - Uses linked tokens for API │
│ - Can push, create PRs, comment │
└─────────────────────────────────┘
```
-103
View File
@@ -1,103 +0,0 @@
# Metrics, Logs & Observability
OpenHands uses multiple systems for monitoring, analytics, and debugging:
```mermaid
flowchart LR
subgraph Sources["Sources"]
Agent["Agent Server"]
App["App Server"]
Frontend["Frontend"]
end
subgraph Collection["Collection"]
JSONLog["JSON Logs"]
Metrics["Metrics"]
PH["PostHog"]
end
subgraph Services["Services"]
DD["DataDog"]
PHCloud["PostHog Cloud"]
end
Agent --> JSONLog
App --> JSONLog
App --> PH
Frontend --> PH
JSONLog --> DD
Metrics --> DD
PH --> PHCloud
```
### Logging Infrastructure
| Component | Format | Destination | Purpose |
|-----------|--------|-------------|---------|
| **Application Logs** | JSON (when `LOG_JSON=1`) | stdout → DataDog | Debugging, error tracking |
| **Access Logs** | JSON (Uvicorn) | stdout → DataDog | Request tracing |
| **LLM Debug Logs** | Plain text | File (optional) | LLM call debugging |
### JSON Log Format
When `LOG_JSON=1` is set, all logs are emitted as single-line JSON for DataDog ingestion:
```json
{
"message": "Conversation started",
"severity": "INFO",
"conversation_id": "abc-123",
"user_id": "user-456",
"timestamp": "2024-01-15T10:30:00Z"
}
```
### Metrics Tracked
| Metric | Tracked By | Storage | Purpose |
|--------|------------|---------|---------|
| **LLM Cost** | `Metrics` class | Conversation stats file | Billing, budget limits |
| **Token Usage** | `Metrics` class | Conversation stats file | Usage analytics |
| **Response Latency** | `Metrics` class | Conversation stats file | Performance monitoring |
| **User Events** | PostHog | PostHog Cloud | Product analytics |
| **Feature Flags** | PostHog | PostHog Cloud | Gradual rollouts |
### PostHog Analytics
PostHog is used for both product analytics and feature flags:
**Frontend Events:**
- `conversation_started`
- `download_trajectory_button_clicked`
- Feature flag checks
**Backend Events:**
- Experiment assignments
- Conversion tracking
### DataDog Integration
Logs are ingested by DataDog through structured JSON output:
1. **Log Collection**: Container stdout/stderr → DataDog Agent → DataDog Logs
2. **APM Traces**: Distributed tracing across services (when enabled)
3. **Dashboards**: Custom dashboards for:
- Error rates by service
- Request latency percentiles
- Conversation success rates
- LLM cost tracking
### Conversation Stats Persistence
Per-conversation metrics are persisted for billing and analytics:
```python
# Location: openhands/server/services/conversation_stats.py
ConversationStats:
- service_to_metrics: Dict[str, Metrics]
- accumulated_cost: float
- token_usage: TokenUsage
# Stored at: {file_store}/conversation_stats/{conversation_id}.pkl
```
@@ -1,64 +0,0 @@
# System Architecture Overview
OpenHands uses a multi-tier architecture with these main components:
```mermaid
flowchart TB
subgraph AppServer["OpenHands App Server (Single Instance)"]
API["REST API<br/>(FastAPI)"]
Auth["Authentication"]
ConvMgr["Conversation<br/>Manager"]
SandboxSvc["Sandbox<br/>Service"]
end
subgraph RuntimeAPI["Runtime API (Separate Service)"]
RuntimeMgr["Runtime<br/>Manager"]
WarmPool["Warm Runtime<br/>Pool"]
end
subgraph Sandbox["Sandbox (Docker/K8s Container)"]
AS["Agent Server<br/>(openhands-agent-server)"]
AES["Action Execution<br/>Server"]
Browser["Browser<br/>Environment"]
FS["File System"]
end
User["User"] -->|"1. HTTP/REST"| API
API --> Auth
Auth --> ConvMgr
ConvMgr --> SandboxSvc
SandboxSvc -->|"2. POST /start"| RuntimeMgr
RuntimeMgr -->|"Check pool"| WarmPool
WarmPool -->|"Warm runtime<br/>available?"| RuntimeMgr
RuntimeMgr -->|"3. Provision or<br/>assign runtime"| Sandbox
User -.->|"4. WebSocket<br/>(Direct)"| AS
AS -->|"HTTP"| AES
AES --> Browser
AES --> FS
```
### Component Responsibilities
| Component | Location | Instances | Purpose |
|-----------|----------|-----------|---------|
| **App Server** | Host | 1 per deployment | REST API, auth, conversation management |
| **Sandbox Service** | Inside App Server | 1 | Manages sandbox lifecycle, calls Runtime API |
| **Runtime API** | Separate service | 1 per deployment | Provisions runtimes, manages warm pool |
| **Agent Server** | Inside sandbox | 1 per sandbox | AI agent loop, LLM calls, state management |
| **Action Execution Server** | Inside sandbox | 1 per sandbox | Execute bash, file ops, browser actions |
### Runtime API Endpoints
The Runtime API manages the actual container/pod lifecycle:
| Endpoint | Purpose |
|----------|---------|
| `POST /start` | Start a new runtime (or assign from warm pool) |
| `POST /stop` | Stop and clean up a runtime |
| `POST /pause` | Pause a running runtime |
| `POST /resume` | Resume a paused runtime |
| `GET /sessions/{id}` | Get runtime status |
| `GET /list` | List all active runtimes |
+16 -1
View File
@@ -136,7 +136,7 @@ class LLM(RetryMixin, DebugMixin):
if self.config.model.startswith('openhands/'):
model_name = self.config.model.removeprefix('openhands/')
self.config.model = f'litellm_proxy/{model_name}'
self.config.base_url = 'https://llm-proxy.app.all-hands.dev/'
self.config.base_url = _get_openhands_llm_base_url()
logger.debug(
f'Rewrote openhands/{model_name} to {self.config.model} with base URL {self.config.base_url}'
)
@@ -851,3 +851,18 @@ class LLM(RetryMixin, DebugMixin):
# let pydantic handle the serialization
return [message.model_dump() for message in messages]
def _get_openhands_llm_base_url():
# Get the API url if specified
lite_llm_api_url = os.getenv('LITE_LLM_API_URL')
if lite_llm_api_url:
return lite_llm_api_url
# Fallback to using web_host.
web_host = os.getenv('WEB_HOST')
if web_host and ('.staging.' in web_host or web_host.startswith('staging')):
return 'https://llm-proxy.staging.all-hands.dev/'
# Use the default
return 'https://llm-proxy.app.all-hands.dev/'
+5 -1
View File
@@ -1,5 +1,7 @@
from __future__ import annotations
from typing import Annotated
from pydantic import (
BaseModel,
ConfigDict,
@@ -31,7 +33,9 @@ class Settings(BaseModel):
user_version: int | None = None
remote_runtime_resource_factor: int | None = None
# Planned to be removed from settings
secrets_store: Secrets = Field(default_factory=Secrets, frozen=True)
secrets_store: Annotated[Secrets, Field(frozen=True)] = Field(
default_factory=Secrets
)
enable_default_condenser: bool = True
enable_sound_notifications: bool = False
enable_proactive_conversation_starters: bool = True
Generated
+10 -10
View File
@@ -7731,14 +7731,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
[[package]]
name = "openhands-agent-server"
version = "1.9.0"
version = "1.10.0"
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_agent_server-1.9.0-py3-none-any.whl", hash = "sha256:44b65fac5bb831541eb2e8726afb2682bde4816b4c6c90be9ad3cafd3dbcf971"},
{file = "openhands_agent_server-1.9.0.tar.gz", hash = "sha256:ac41a948acf64ed661a9f383c293c305176f92bd12e6fc6362f5414cb7874ee1"},
{file = "openhands_agent_server-1.10.0-py3-none-any.whl", hash = "sha256:2e21076fff5e7cf9d03a3b011e2c90a6a3a46d2da3f18db9f7553ac413229c22"},
{file = "openhands_agent_server-1.10.0.tar.gz", hash = "sha256:2062da2496a98a6c23201d086f124e02329d6c6d9d1b47be55921c084a29f55a"},
]
[package.dependencies]
@@ -7755,14 +7755,14 @@ wsproto = ">=1.2.0"
[[package]]
name = "openhands-sdk"
version = "1.9.0"
version = "1.10.0"
description = "OpenHands SDK - Core functionality for building AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_sdk-1.9.0-py3-none-any.whl", hash = "sha256:b427d8b9e587a5360c7d61742c290601998557e9b38b1c9e11a297659812c00d"},
{file = "openhands_sdk-1.9.0.tar.gz", hash = "sha256:70048888fd4fbe44a86c35c402bbb99d30cf0cba50579ee1a8e3f43e05154150"},
{file = "openhands_sdk-1.10.0-py3-none-any.whl", hash = "sha256:5c8875f2a07d7fabe3449914639572bef9003821207cb06aa237a239e964eed5"},
{file = "openhands_sdk-1.10.0.tar.gz", hash = "sha256:93371b1af4532266ad2d225b9d7d3d711c745df31888efe643970673f62bdef9"},
]
[package.dependencies]
@@ -7783,14 +7783,14 @@ boto3 = ["boto3 (>=1.35.0)"]
[[package]]
name = "openhands-tools"
version = "1.9.0"
version = "1.10.0"
description = "OpenHands Tools - Runtime tools for AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_tools-1.9.0-py3-none-any.whl", hash = "sha256:8becde0e913a31babb41eb93a8c10bf41d87ca1febd07bc958839c3583655305"},
{file = "openhands_tools-1.9.0.tar.gz", hash = "sha256:d45f5f5210cb2bbcd8ab5f3a32051db1a532d0ec07cd306105f95cde42cf67f2"},
{file = "openhands_tools-1.10.0-py3-none-any.whl", hash = "sha256:1d5d2d1e34cc4ceb02c0ff1f008b06883ad48a8e7236ab8dd61ece64fbf8e2ed"},
{file = "openhands_tools-1.10.0.tar.gz", hash = "sha256:7ed38cb13545ec2c4a35c26ece725d5b35788d30597db8b1904619c043ec1194"},
]
[package.dependencies]
@@ -17367,4 +17367,4 @@ third-party-runtimes = ["daytona", "e2b-code-interpreter", "modal", "runloop-api
[metadata]
lock-version = "2.1"
python-versions = "^3.12,<3.14"
content-hash = "af2159c3b8723a036d7c3f3ddd0b45ce149acd20d164c17856be7db48a35c695"
content-hash = "f67478db2385eb258369313ac831b26582d744294c0996a35e786c3d7ced5db1"
+6 -9
View File
@@ -54,9 +54,9 @@ dependencies = [
"numpy",
"openai==2.8",
"openhands-aci==0.3.2",
"openhands-agent-server==1.9",
"openhands-sdk==1.9",
"openhands-tools==1.9",
"openhands-agent-server==1.10",
"openhands-sdk==1.10",
"openhands-tools==1.10",
"opentelemetry-api>=1.33.1",
"opentelemetry-exporter-otlp-proto-grpc>=1.33.1",
"pathspec>=0.12.1",
@@ -280,12 +280,9 @@ e2b-code-interpreter = { version = "^2.0.0", optional = true }
pybase62 = "^1.0.0"
# V1 dependencies
#openhands-agent-server = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-agent-server", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
#openhands-sdk = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-sdk", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
#openhands-tools = { git = "https://github.com/OpenHands/agent-sdk.git", subdirectory = "openhands-tools", rev = "15f565b8ac38876e40dc05c08e2b04ccaae4a66d" }
openhands-sdk = "1.9.0"
openhands-agent-server = "1.9.0"
openhands-tools = "1.9.0"
openhands-sdk = "1.10"
openhands-agent-server = "1.10"
openhands-tools = "1.10"
python-jose = { version = ">=3.3", extras = [ "cryptography" ] }
sqlalchemy = { extras = [ "asyncio" ], version = "^2.0.40" }
pg8000 = "^1.31.5"
@@ -17,6 +17,7 @@ from openhands.app_server.app_conversation.app_conversation_service_base import
)
from openhands.app_server.sandbox.sandbox_models import SandboxInfo
from openhands.app_server.user.user_context import UserContext
from openhands.sdk.context.skills import Skill
class MockUserInfo:
@@ -920,347 +921,251 @@ async def test_configure_git_user_settings_special_characters_in_name(mock_works
# =============================================================================
# Tests for load_and_merge_all_skills with org skills
# Tests for load_and_merge_all_skills (updated to use agent-server)
# =============================================================================
class TestLoadAndMergeAllSkillsWithOrgSkills:
"""Test load_and_merge_all_skills includes organization skills."""
class TestMergeSkills:
"""Test _merge_skills method."""
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
)
async def test_load_and_merge_includes_org_skills(
self,
mock_load_repo,
mock_load_org,
mock_load_user,
mock_load_global,
mock_load_sandbox,
):
"""Test that load_and_merge_all_skills loads and merges org skills."""
def test_merges_skills_with_no_duplicates(self):
"""Test merging skill lists with no duplicate names."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(
AppConversationServiceBase,
'__abstractmethods__',
set(),
):
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True,
user_context=mock_user_context,
init_git_in_empty_workspace=True, user_context=mock_user_context
)
sandbox = Mock(spec=SandboxInfo)
sandbox.exposed_urls = []
remote_workspace = AsyncMock()
skill1 = Mock(spec=Skill)
skill1.name = 'skill1'
skill2 = Mock(spec=Skill)
skill2.name = 'skill2'
skill3 = Mock(spec=Skill)
skill3.name = 'skill3'
# Create distinct mock skills for each source
sandbox_skill = Mock()
sandbox_skill.name = 'sandbox_skill'
global_skill = Mock()
global_skill.name = 'global_skill'
user_skill = Mock()
user_skill.name = 'user_skill'
org_skill = Mock()
org_skill.name = 'org_skill'
repo_skill = Mock()
repo_skill.name = 'repo_skill'
mock_load_sandbox.return_value = [sandbox_skill]
mock_load_global.return_value = [global_skill]
mock_load_user.return_value = [user_skill]
mock_load_org.return_value = [org_skill]
mock_load_repo.return_value = [repo_skill]
skill_lists = [[skill1], [skill2], [skill3]]
# Act
result = await service.load_and_merge_all_skills(
sandbox, remote_workspace, 'owner/repo', '/workspace'
)
result = service._merge_skills(skill_lists)
# Assert
assert len(result) == 5
assert len(result) == 3
names = {s.name for s in result}
assert names == {
'sandbox_skill',
'global_skill',
'user_skill',
'org_skill',
'repo_skill',
}
mock_load_org.assert_called_once_with(
remote_workspace, 'owner/repo', '/workspace', mock_user_context
)
assert names == {'skill1', 'skill2', 'skill3'}
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
)
async def test_load_and_merge_org_skills_precedence(
self,
mock_load_repo,
mock_load_org,
mock_load_user,
mock_load_global,
mock_load_sandbox,
):
"""Test that org skills have correct precedence (higher than user, lower than repo)."""
def test_merges_skills_with_duplicates_later_wins(self):
"""Test that later skill lists override earlier ones for duplicate names."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(
AppConversationServiceBase,
'__abstractmethods__',
set(),
):
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True,
user_context=mock_user_context,
init_git_in_empty_workspace=True, user_context=mock_user_context
)
sandbox = Mock(spec=SandboxInfo)
sandbox.exposed_urls = []
remote_workspace = AsyncMock()
skill1_v1 = Mock(spec=Skill)
skill1_v1.name = 'skill1'
skill1_v1.version = 'v1'
# Create skills with same name but different sources
user_skill = Mock()
user_skill.name = 'common_skill'
user_skill.source = 'user'
skill1_v2 = Mock(spec=Skill)
skill1_v2.name = 'skill1'
skill1_v2.version = 'v2'
org_skill = Mock()
org_skill.name = 'common_skill'
org_skill.source = 'org'
skill2 = Mock(spec=Skill)
skill2.name = 'skill2'
repo_skill = Mock()
repo_skill.name = 'common_skill'
repo_skill.source = 'repo'
mock_load_sandbox.return_value = []
mock_load_global.return_value = []
mock_load_user.return_value = [user_skill]
mock_load_org.return_value = [org_skill]
mock_load_repo.return_value = [repo_skill]
skill_lists = [[skill1_v1], [skill1_v2, skill2]]
# Act
result = await service.load_and_merge_all_skills(
sandbox, remote_workspace, 'owner/repo', '/workspace'
)
result = service._merge_skills(skill_lists)
# Assert
# Should have only one skill with repo source (highest precedence)
assert len(result) == 1
assert result[0].source == 'repo'
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
)
async def test_load_and_merge_org_skills_override_user_skills(
self,
mock_load_repo,
mock_load_org,
mock_load_user,
mock_load_global,
mock_load_sandbox,
):
"""Test that org skills override user skills for same name."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(
AppConversationServiceBase,
'__abstractmethods__',
set(),
):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True,
user_context=mock_user_context,
)
sandbox = Mock(spec=SandboxInfo)
sandbox.exposed_urls = []
remote_workspace = AsyncMock()
# Create skills with same name
user_skill = Mock()
user_skill.name = 'shared_skill'
user_skill.priority = 'low'
org_skill = Mock()
org_skill.name = 'shared_skill'
org_skill.priority = 'high'
mock_load_sandbox.return_value = []
mock_load_global.return_value = []
mock_load_user.return_value = [user_skill]
mock_load_org.return_value = [org_skill]
mock_load_repo.return_value = []
# Act
result = await service.load_and_merge_all_skills(
sandbox, remote_workspace, 'owner/repo', '/workspace'
)
# Assert
assert len(result) == 1
assert result[0].priority == 'high' # Org skill should win
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
)
async def test_load_and_merge_handles_org_skills_failure(
self,
mock_load_repo,
mock_load_org,
mock_load_user,
mock_load_global,
mock_load_sandbox,
):
"""Test that failure to load org skills doesn't break the overall process."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(
AppConversationServiceBase,
'__abstractmethods__',
set(),
):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True,
user_context=mock_user_context,
)
sandbox = Mock(spec=SandboxInfo)
sandbox.exposed_urls = []
remote_workspace = AsyncMock()
global_skill = Mock()
global_skill.name = 'global_skill'
repo_skill = Mock()
repo_skill.name = 'repo_skill'
mock_load_sandbox.return_value = []
mock_load_global.return_value = [global_skill]
mock_load_user.return_value = []
mock_load_org.return_value = [] # Org skills failed/empty
mock_load_repo.return_value = [repo_skill]
# Act
result = await service.load_and_merge_all_skills(
sandbox, remote_workspace, 'owner/repo', '/workspace'
)
# Assert
# Should still have skills from other sources
assert len(result) == 2
names = {s.name for s in result}
assert names == {'global_skill', 'repo_skill'}
skill1_result = next(s for s in result if s.name == 'skill1')
assert skill1_result.version == 'v2'
class TestLoadAndMergeAllSkills:
"""Test load_and_merge_all_skills method (updated to use agent-server)."""
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_sandbox_skills'
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_global_skills'
'openhands.app_server.app_conversation.app_conversation_service_base.build_org_config'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_user_skills'
'openhands.app_server.app_conversation.app_conversation_service_base.build_sandbox_config'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_org_skills'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_repo_skills'
)
async def test_load_and_merge_no_selected_repository(
async def test_loads_skills_successfully(
self,
mock_load_repo,
mock_load_org,
mock_load_user,
mock_load_global,
mock_load_sandbox,
mock_build_sandbox_config,
mock_build_org_config,
mock_load_skills,
):
"""Test skill loading when no repository is selected."""
"""Test successfully loading skills from agent-server."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(
AppConversationServiceBase,
'__abstractmethods__',
set(),
):
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True,
user_context=mock_user_context,
init_git_in_empty_workspace=True, user_context=mock_user_context
)
mock_workspace = AsyncMock()
mock_workspace.working_dir = '/workspace'
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
sandbox = Mock(spec=SandboxInfo)
sandbox.exposed_urls = []
remote_workspace = AsyncMock()
exposed_url = ExposedUrl(
name='AGENT_SERVER', url='http://localhost:8000', port=8000
)
sandbox.exposed_urls = [exposed_url]
sandbox.session_api_key = 'test-api-key'
global_skill = Mock()
global_skill.name = 'global_skill'
skill1 = Mock(spec=Skill)
skill1.name = 'skill1'
skill2 = Mock(spec=Skill)
skill2.name = 'skill2'
mock_load_sandbox.return_value = []
mock_load_global.return_value = [global_skill]
mock_load_user.return_value = []
mock_load_org.return_value = []
mock_load_repo.return_value = []
mock_load_skills.return_value = [skill1, skill2]
mock_build_org_config.return_value = {'repository': 'owner/repo'}
mock_build_sandbox_config.return_value = {'exposed_urls': []}
# Act
result = await service.load_and_merge_all_skills(
sandbox, remote_workspace, None, '/workspace'
sandbox, 'owner/repo', '/workspace', 'http://localhost:8000'
)
# Assert
assert len(result) == 1
# Org skills should be called even with None repository
mock_load_org.assert_called_once_with(
remote_workspace, None, '/workspace', mock_user_context
assert len(result) == 2
assert result[0].name == 'skill1'
assert result[1].name == 'skill2'
mock_load_skills.assert_called_once()
call_kwargs = mock_load_skills.call_args[1]
assert call_kwargs['agent_server_url'] == 'http://localhost:8000'
assert call_kwargs['session_api_key'] == 'test-api-key'
assert call_kwargs['project_dir'] == '/workspace/repo'
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
)
async def test_returns_empty_list_when_no_agent_server_url(self, mock_load_skills):
"""Test returns empty list when agent-server URL is not available."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True, user_context=mock_user_context
)
AsyncMock()
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
sandbox = Mock(spec=SandboxInfo)
exposed_url = ExposedUrl(
name='VSCODE', url='http://localhost:8080', port=8080
)
sandbox.exposed_urls = [exposed_url]
# Act - pass empty string to simulate no agent server URL
# This should still call load_skills_from_agent_server but it will fail
result = await service.load_and_merge_all_skills(
sandbox, 'owner/repo', '/workspace', ''
)
# Assert - should return empty list when agent_server_url is empty
assert result == []
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.build_org_config'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.build_sandbox_config'
)
async def test_uses_working_dir_when_no_repository(
self,
mock_build_sandbox_config,
mock_build_org_config,
mock_load_skills,
):
"""Test uses working_dir as project_dir when no repository is selected."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True, user_context=mock_user_context
)
AsyncMock()
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
sandbox = Mock(spec=SandboxInfo)
exposed_url = ExposedUrl(
name='AGENT_SERVER', url='http://localhost:8000', port=8000
)
sandbox.exposed_urls = [exposed_url]
sandbox.session_api_key = 'test-key'
mock_load_skills.return_value = []
mock_build_org_config.return_value = None
mock_build_sandbox_config.return_value = None
# Act
await service.load_and_merge_all_skills(
sandbox, None, '/workspace', 'http://localhost:8000'
)
# Assert
call_kwargs = mock_load_skills.call_args[1]
assert call_kwargs['project_dir'] == '/workspace'
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.load_skills_from_agent_server'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.build_org_config'
)
@patch(
'openhands.app_server.app_conversation.app_conversation_service_base.build_sandbox_config'
)
async def test_handles_exception_gracefully(
self,
mock_build_sandbox_config,
mock_build_org_config,
mock_load_skills,
):
"""Test handles exceptions during skill loading."""
# Arrange
mock_user_context = Mock(spec=UserContext)
with patch.object(AppConversationServiceBase, '__abstractmethods__', set()):
service = AppConversationServiceBase(
init_git_in_empty_workspace=True, user_context=mock_user_context
)
AsyncMock()
from openhands.app_server.sandbox.sandbox_models import ExposedUrl
sandbox = Mock(spec=SandboxInfo)
exposed_url = ExposedUrl(
name='AGENT_SERVER', url='http://localhost:8000', port=8000
)
sandbox.exposed_urls = [exposed_url]
sandbox.session_api_key = 'test-key'
mock_load_skills.side_effect = Exception('Network error')
# Act
result = await service.load_and_merge_all_skills(
sandbox, 'owner/repo', '/workspace', 'http://localhost:8000'
)
# Assert
assert result == []
@@ -1165,6 +1165,50 @@ class TestLiveStatusAppConversationService:
)
self.mock_event_service.search_events.assert_called_once()
@pytest.mark.asyncio
async def test_export_conversation_calls_search_events_with_correct_parameter_name(
self,
):
"""Test that export_conversation calls search_events with 'conversation_id' parameter, not 'conversation_id__eq'.
This test verifies the fix for a bug where page_iterator was called with
conversation_id__eq instead of conversation_id, causing a TypeError since
the search_events method expects conversation_id as its parameter name.
"""
# Arrange
conversation_id = uuid4()
# Mock conversation info
mock_conversation_info = Mock(spec=AppConversationInfo)
mock_conversation_info.id = conversation_id
mock_conversation_info.model_dump_json = Mock(return_value='{}')
self.mock_app_conversation_info_service.get_app_conversation_info = AsyncMock(
return_value=mock_conversation_info
)
# Mock empty event page to simplify test
mock_event_page = Mock()
mock_event_page.items = []
mock_event_page.next_page_id = None
self.mock_event_service.search_events = AsyncMock(return_value=mock_event_page)
# Act
await self.service.export_conversation(conversation_id)
# Assert - Verify search_events was called with 'conversation_id', not 'conversation_id__eq'
self.mock_event_service.search_events.assert_called()
call_kwargs = self.mock_event_service.search_events.call_args[1]
assert 'conversation_id' in call_kwargs, (
"search_events should be called with 'conversation_id' parameter"
)
assert 'conversation_id__eq' not in call_kwargs, (
"search_events should NOT be called with 'conversation_id__eq' parameter"
)
assert call_kwargs['conversation_id'] == conversation_id
@pytest.mark.asyncio
async def test_export_conversation_large_pagination(self):
"""Test download with multiple pages of events."""
@@ -1288,7 +1332,7 @@ class TestLiveStatusAppConversationService:
task.sandbox_id = self.mock_sandbox.id
yield task
async def mock_run_setup_scripts(task, sandbox, workspace):
async def mock_run_setup_scripts(task, sandbox, workspace, agent_server_url):
yield task
self.service._wait_for_sandbox_start = mock_wait_for_sandbox
@@ -1742,3 +1786,855 @@ class TestLiveStatusAppConversationService:
stdio_server = mcp_servers['stdio-server']
assert stdio_server['command'] == 'npx'
assert stdio_server['env'] == {'TOKEN': 'value'}
class TestPluginHandling:
"""Test cases for plugin-related functionality in LiveStatusAppConversationService."""
def setup_method(self):
"""Set up test fixtures."""
# Create mock dependencies
self.mock_user_context = Mock(spec=UserContext)
self.mock_user_auth = Mock()
self.mock_user_context.user_auth = self.mock_user_auth
self.mock_jwt_service = Mock()
self.mock_sandbox_service = Mock()
self.mock_sandbox_spec_service = Mock()
self.mock_app_conversation_info_service = Mock()
self.mock_app_conversation_start_task_service = Mock()
self.mock_event_callback_service = Mock()
self.mock_event_service = Mock()
self.mock_httpx_client = Mock()
# Create service instance
self.service = LiveStatusAppConversationService(
init_git_in_empty_workspace=True,
user_context=self.mock_user_context,
app_conversation_info_service=self.mock_app_conversation_info_service,
app_conversation_start_task_service=self.mock_app_conversation_start_task_service,
event_callback_service=self.mock_event_callback_service,
event_service=self.mock_event_service,
sandbox_service=self.mock_sandbox_service,
sandbox_spec_service=self.mock_sandbox_spec_service,
jwt_service=self.mock_jwt_service,
sandbox_startup_timeout=30,
sandbox_startup_poll_frequency=1,
httpx_client=self.mock_httpx_client,
web_url='https://test.example.com',
openhands_provider_base_url='https://provider.example.com',
access_token_hard_timeout=None,
app_mode='test',
)
# Mock user info
self.mock_user = Mock()
self.mock_user.id = 'test_user_123'
self.mock_user.llm_model = 'gpt-4'
self.mock_user.llm_base_url = 'https://api.openai.com/v1'
self.mock_user.llm_api_key = 'test_api_key'
self.mock_user.confirmation_mode = False
self.mock_user.search_api_key = None
self.mock_user.condenser_max_size = None
self.mock_user.mcp_config = None
self.mock_user.security_analyzer = None
# Mock sandbox
self.mock_sandbox = Mock(spec=SandboxInfo)
self.mock_sandbox.id = uuid4()
self.mock_sandbox.status = SandboxStatus.RUNNING
def test_construct_initial_message_with_plugin_params_no_plugins(self):
"""Test _construct_initial_message_with_plugin_params with no plugins returns original message."""
from openhands.agent_server.models import SendMessageRequest, TextContent
# Test with None initial message and None plugins
result = self.service._construct_initial_message_with_plugin_params(None, None)
assert result is None
# Test with None initial message and empty plugins list
result = self.service._construct_initial_message_with_plugin_params(None, [])
assert result is None
# Test with initial message but None plugins
initial_msg = SendMessageRequest(content=[TextContent(text='Hello world')])
result = self.service._construct_initial_message_with_plugin_params(
initial_msg, None
)
assert result is initial_msg
def test_construct_initial_message_with_plugin_params_no_params(self):
"""Test _construct_initial_message_with_plugin_params with plugins but no parameters."""
from openhands.agent_server.models import SendMessageRequest, TextContent
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
# Plugin with no parameters
plugins = [PluginSpec(source='github:owner/repo')]
# Test with None initial message
result = self.service._construct_initial_message_with_plugin_params(
None, plugins
)
assert result is None
# Test with initial message
initial_msg = SendMessageRequest(content=[TextContent(text='Hello world')])
result = self.service._construct_initial_message_with_plugin_params(
initial_msg, plugins
)
assert result is initial_msg
def test_construct_initial_message_with_plugin_params_creates_new_message(self):
"""Test _construct_initial_message_with_plugin_params creates message when no initial message."""
from openhands.agent_server.models import TextContent
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugins = [
PluginSpec(
source='github:owner/repo',
parameters={'api_key': 'test123', 'debug': True},
)
]
result = self.service._construct_initial_message_with_plugin_params(
None, plugins
)
assert result is not None
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
assert 'Plugin Configuration Parameters:' in result.content[0].text
assert '- api_key: test123' in result.content[0].text
assert '- debug: True' in result.content[0].text
assert result.run is True
def test_construct_initial_message_with_plugin_params_appends_to_message(self):
"""Test _construct_initial_message_with_plugin_params appends to existing message."""
from openhands.agent_server.models import SendMessageRequest, TextContent
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
initial_msg = SendMessageRequest(
content=[TextContent(text='Please analyze this codebase')],
run=False,
)
plugins = [
PluginSpec(
source='github:owner/repo',
ref='v1.0.0',
parameters={'target_dir': '/src', 'verbose': True},
)
]
result = self.service._construct_initial_message_with_plugin_params(
initial_msg, plugins
)
assert result is not None
assert len(result.content) == 1
text = result.content[0].text
assert text.startswith('Please analyze this codebase')
assert 'Plugin Configuration Parameters:' in text
assert '- target_dir: /src' in text
assert '- verbose: True' in text
assert result.run is False
def test_construct_initial_message_with_plugin_params_preserves_role(self):
"""Test _construct_initial_message_with_plugin_params preserves message role."""
from openhands.agent_server.models import SendMessageRequest, TextContent
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
initial_msg = SendMessageRequest(
role='system',
content=[TextContent(text='System message')],
)
plugins = [PluginSpec(source='github:owner/repo', parameters={'key': 'value'})]
result = self.service._construct_initial_message_with_plugin_params(
initial_msg, plugins
)
assert result is not None
assert result.role == 'system'
def test_construct_initial_message_with_plugin_params_empty_content(self):
"""Test _construct_initial_message_with_plugin_params handles empty content list."""
from openhands.agent_server.models import SendMessageRequest, TextContent
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
initial_msg = SendMessageRequest(content=[])
plugins = [PluginSpec(source='github:owner/repo', parameters={'key': 'value'})]
result = self.service._construct_initial_message_with_plugin_params(
initial_msg, plugins
)
assert result is not None
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
assert 'Plugin Configuration Parameters:' in result.content[0].text
def test_construct_initial_message_with_multiple_plugins(self):
"""Test _construct_initial_message_with_plugin_params handles multiple plugins."""
from openhands.agent_server.models import TextContent
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugins = [
PluginSpec(
source='github:owner/plugin1',
parameters={'key1': 'value1'},
),
PluginSpec(
source='github:owner/plugin2',
parameters={'key2': 'value2'},
),
]
result = self.service._construct_initial_message_with_plugin_params(
None, plugins
)
assert result is not None
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
text = result.content[0].text
assert 'Plugin Configuration Parameters:' in text
# Multiple plugins should show grouped by plugin name
assert 'plugin1' in text
assert 'plugin2' in text
assert 'key1: value1' in text
assert 'key2: value2' in text
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
)
async def test_finalize_conversation_request_with_plugins(
self, mock_experiment_manager
):
"""Test _finalize_conversation_request passes plugins list to StartConversationRequest."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
# Arrange
mock_agent = Mock(spec=Agent)
mock_llm = Mock(spec=LLM)
mock_llm.model = 'gpt-4'
mock_llm.usage_id = 'agent'
mock_updated_agent = Mock(spec=Agent)
mock_updated_agent.llm = mock_llm
mock_updated_agent.condenser = None
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
mock_updated_agent
)
workspace = LocalWorkspace(working_dir='/test')
secrets = {'test': StaticSecret(value='secret')}
plugins = [
PluginSpec(
source='github:owner/my-plugin',
ref='v1.0.0',
parameters={'api_key': 'test123'},
)
]
# Act
result = await self.service._finalize_conversation_request(
mock_agent,
None,
self.mock_user,
workspace,
None,
secrets,
self.mock_sandbox,
None,
None,
'/test/dir',
plugins=plugins,
)
# Assert
assert isinstance(result, StartConversationRequest)
assert result.plugins is not None
assert len(result.plugins) == 1
assert result.plugins[0].source == 'github:owner/my-plugin'
assert result.plugins[0].ref == 'v1.0.0'
# Also verify initial message contains plugin params
assert result.initial_message is not None
assert (
'Plugin Configuration Parameters:' in result.initial_message.content[0].text
)
assert '- api_key: test123' in result.initial_message.content[0].text
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
)
async def test_finalize_conversation_request_without_plugins(
self, mock_experiment_manager
):
"""Test _finalize_conversation_request without plugins sets plugins to None."""
# Arrange
mock_agent = Mock(spec=Agent)
mock_llm = Mock(spec=LLM)
mock_llm.model = 'gpt-4'
mock_llm.usage_id = 'agent'
mock_updated_agent = Mock(spec=Agent)
mock_updated_agent.llm = mock_llm
mock_updated_agent.condenser = None
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
mock_updated_agent
)
workspace = LocalWorkspace(working_dir='/test')
secrets = {}
# Act
result = await self.service._finalize_conversation_request(
mock_agent,
None,
self.mock_user,
workspace,
None,
secrets,
self.mock_sandbox,
None,
None,
'/test/dir',
plugins=None,
)
# Assert
assert isinstance(result, StartConversationRequest)
assert result.plugins is None
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
)
async def test_finalize_conversation_request_plugin_without_ref(
self, mock_experiment_manager
):
"""Test _finalize_conversation_request with plugin that has no ref."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
# Arrange
mock_agent = Mock(spec=Agent)
mock_llm = Mock(spec=LLM)
mock_llm.model = 'gpt-4'
mock_llm.usage_id = 'agent'
mock_updated_agent = Mock(spec=Agent)
mock_updated_agent.llm = mock_llm
mock_updated_agent.condenser = None
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
mock_updated_agent
)
workspace = LocalWorkspace(working_dir='/test')
secrets = {}
# Plugin without ref or parameters
plugins = [PluginSpec(source='github:owner/my-plugin')]
# Act
result = await self.service._finalize_conversation_request(
mock_agent,
None,
self.mock_user,
workspace,
None,
secrets,
self.mock_sandbox,
None,
None,
'/test/dir',
plugins=plugins,
)
# Assert
assert isinstance(result, StartConversationRequest)
assert result.plugins is not None
assert len(result.plugins) == 1
assert result.plugins[0].source == 'github:owner/my-plugin'
assert result.plugins[0].ref is None
# No parameters, so initial message should be None
assert result.initial_message is None
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
)
async def test_finalize_conversation_request_plugin_with_repo_path(
self, mock_experiment_manager
):
"""Test _finalize_conversation_request passes repo_path to PluginSource."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
# Arrange
mock_agent = Mock(spec=Agent)
mock_llm = Mock(spec=LLM)
mock_llm.model = 'gpt-4'
mock_llm.usage_id = 'agent'
mock_updated_agent = Mock(spec=Agent)
mock_updated_agent.llm = mock_llm
mock_updated_agent.condenser = None
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
mock_updated_agent
)
workspace = LocalWorkspace(working_dir='/test')
secrets = {}
# Plugin with repo_path (for marketplace repos containing multiple plugins)
plugins = [
PluginSpec(
source='github:owner/marketplace-repo',
ref='main',
repo_path='plugins/city-weather',
)
]
# Act
result = await self.service._finalize_conversation_request(
mock_agent,
None,
self.mock_user,
workspace,
None,
secrets,
self.mock_sandbox,
None,
None,
'/test/dir',
plugins=plugins,
)
# Assert
assert isinstance(result, StartConversationRequest)
assert result.plugins is not None
assert len(result.plugins) == 1
assert result.plugins[0].source == 'github:owner/marketplace-repo'
assert result.plugins[0].ref == 'main'
assert result.plugins[0].repo_path == 'plugins/city-weather'
@pytest.mark.asyncio
@patch(
'openhands.app_server.app_conversation.live_status_app_conversation_service.ExperimentManagerImpl'
)
async def test_finalize_conversation_request_multiple_plugins(
self, mock_experiment_manager
):
"""Test _finalize_conversation_request with multiple plugins."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
# Arrange
mock_agent = Mock(spec=Agent)
mock_llm = Mock(spec=LLM)
mock_llm.model = 'gpt-4'
mock_llm.usage_id = 'agent'
mock_updated_agent = Mock(spec=Agent)
mock_updated_agent.llm = mock_llm
mock_updated_agent.condenser = None
mock_experiment_manager.run_agent_variant_tests__v1.return_value = (
mock_updated_agent
)
workspace = LocalWorkspace(working_dir='/test')
secrets = {}
# Multiple plugins
plugins = [
PluginSpec(source='github:owner/security-plugin', ref='v2.0.0'),
PluginSpec(
source='github:owner/monorepo',
repo_path='plugins/logging',
),
PluginSpec(source='/local/path/to/plugin'),
]
# Act
result = await self.service._finalize_conversation_request(
mock_agent,
None,
self.mock_user,
workspace,
None,
secrets,
self.mock_sandbox,
None,
None,
'/test/dir',
plugins=plugins,
)
# Assert
assert isinstance(result, StartConversationRequest)
assert result.plugins is not None
assert len(result.plugins) == 3
assert result.plugins[0].source == 'github:owner/security-plugin'
assert result.plugins[0].ref == 'v2.0.0'
assert result.plugins[1].source == 'github:owner/monorepo'
assert result.plugins[1].repo_path == 'plugins/logging'
assert result.plugins[2].source == '/local/path/to/plugin'
@pytest.mark.asyncio
async def test_build_start_conversation_request_for_user_with_plugins(self):
"""Test _build_start_conversation_request_for_user passes plugins to finalize method."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
# Arrange
self.mock_user_context.get_user_info.return_value = self.mock_user
self.mock_user_context.get_secrets.return_value = {}
self.mock_user_context.get_provider_tokens = AsyncMock(return_value=None)
self.mock_user_context.get_mcp_api_key.return_value = None
plugins = [
PluginSpec(
source='https://github.com/org/plugin.git',
ref='main',
parameters={'config_file': 'custom.yaml'},
)
]
# Mock _finalize_conversation_request to capture the call
mock_finalize = AsyncMock(return_value=Mock(spec=StartConversationRequest))
self.service._finalize_conversation_request = mock_finalize
# Act
await self.service._build_start_conversation_request_for_user(
self.mock_sandbox,
None,
None,
None,
'/workspace',
plugins=plugins,
)
# Assert
mock_finalize.assert_called_once()
call_kwargs = mock_finalize.call_args.kwargs
assert call_kwargs['plugins'] == plugins
@pytest.mark.asyncio
async def test_build_start_conversation_request_for_user_without_plugins(self):
"""Test _build_start_conversation_request_for_user works without plugins."""
# Arrange
self.mock_user_context.get_user_info.return_value = self.mock_user
self.mock_user_context.get_secrets.return_value = {}
self.mock_user_context.get_provider_tokens = AsyncMock(return_value=None)
self.mock_user_context.get_mcp_api_key.return_value = None
# Mock _finalize_conversation_request
mock_finalize = AsyncMock(return_value=Mock(spec=StartConversationRequest))
self.service._finalize_conversation_request = mock_finalize
# Act
await self.service._build_start_conversation_request_for_user(
self.mock_sandbox,
None,
None,
None,
'/workspace',
)
# Assert
mock_finalize.assert_called_once()
call_kwargs = mock_finalize.call_args.kwargs
assert call_kwargs.get('plugins') is None
class TestPluginSpecModel:
"""Test cases for the PluginSpec model."""
def test_plugin_spec_with_all_fields(self):
"""Test PluginSpec with all fields provided."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(
source='github:owner/repo',
ref='v1.0.0',
repo_path='plugins/my-plugin',
parameters={'key1': 'value1', 'key2': 123, 'key3': True},
)
assert plugin.source == 'github:owner/repo'
assert plugin.ref == 'v1.0.0'
assert plugin.repo_path == 'plugins/my-plugin'
assert plugin.parameters == {'key1': 'value1', 'key2': 123, 'key3': True}
def test_plugin_spec_with_only_source(self):
"""Test PluginSpec with only source provided."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(source='https://github.com/owner/repo.git')
assert plugin.source == 'https://github.com/owner/repo.git'
assert plugin.ref is None
assert plugin.repo_path is None
assert plugin.parameters is None
def test_plugin_spec_serialization(self):
"""Test PluginSpec serialization to JSON."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(
source='github:owner/repo',
ref='main',
repo_path='plugins/my-plugin',
parameters={'debug': True},
)
json_data = plugin.model_dump()
assert json_data == {
'source': 'github:owner/repo',
'ref': 'main',
'repo_path': 'plugins/my-plugin',
'parameters': {'debug': True},
}
def test_plugin_spec_deserialization(self):
"""Test PluginSpec deserialization from dict."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
data = {
'source': 'github:owner/repo',
'ref': 'v2.0.0',
'repo_path': 'plugins/weather',
'parameters': {'timeout': 30},
}
plugin = PluginSpec.model_validate(data)
assert plugin.source == 'github:owner/repo'
assert plugin.ref == 'v2.0.0'
assert plugin.repo_path == 'plugins/weather'
assert plugin.parameters == {'timeout': 30}
def test_plugin_spec_display_name_github_format(self):
"""Test display_name extracts repo name from github:owner/repo format."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(source='github:owner/my-plugin')
assert plugin.display_name == 'my-plugin'
def test_plugin_spec_display_name_git_url(self):
"""Test display_name extracts repo name from git URL."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(source='https://github.com/owner/repo.git')
assert plugin.display_name == 'repo.git'
def test_plugin_spec_display_name_local_path(self):
"""Test display_name extracts directory name from local path."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(source='/local/path/to/plugin')
assert plugin.display_name == 'plugin'
def test_plugin_spec_display_name_no_slash(self):
"""Test display_name returns source as-is when no slash present."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(source='local-plugin')
assert plugin.display_name == 'local-plugin'
def test_plugin_spec_format_params_as_text(self):
"""Test format_params_as_text formats parameters as text."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(
source='github:owner/repo',
parameters={'key1': 'value1', 'key2': 123},
)
result = plugin.format_params_as_text()
assert result == '- key1: value1\n- key2: 123'
def test_plugin_spec_format_params_as_text_with_indent(self):
"""Test format_params_as_text with custom indent."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(
source='github:owner/repo',
parameters={'debug': True},
)
result = plugin.format_params_as_text(indent=' ')
assert result == ' - debug: True'
def test_plugin_spec_format_params_as_text_no_params(self):
"""Test format_params_as_text returns None when no parameters."""
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
plugin = PluginSpec(source='github:owner/repo')
assert plugin.format_params_as_text() is None
def test_plugin_spec_inherits_repo_path_validation(self):
"""Test PluginSpec inherits validation from SDK's PluginSource."""
import pytest
from openhands.app_server.app_conversation.app_conversation_models import (
PluginSpec,
)
# Should reject absolute paths
with pytest.raises(ValueError, match='must be relative'):
PluginSpec(source='github:owner/repo', repo_path='/absolute/path')
# Should reject parent traversal
with pytest.raises(ValueError, match="cannot contain '..'"):
PluginSpec(source='github:owner/repo', repo_path='../parent/path')
class TestAppConversationStartRequestWithPlugins:
"""Test cases for AppConversationStartRequest with plugins field."""
def test_start_request_with_plugins(self):
"""Test AppConversationStartRequest with plugins field."""
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartRequest,
PluginSpec,
)
plugins = [
PluginSpec(
source='github:owner/my-plugin',
ref='v1.0.0',
parameters={'api_key': 'test'},
)
]
request = AppConversationStartRequest(
title='Test conversation',
plugins=plugins,
)
assert request.plugins is not None
assert len(request.plugins) == 1
assert request.plugins[0].source == 'github:owner/my-plugin'
assert request.plugins[0].ref == 'v1.0.0'
assert request.plugins[0].parameters == {'api_key': 'test'}
def test_start_request_without_plugins(self):
"""Test AppConversationStartRequest without plugins field (backwards compatible)."""
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartRequest,
)
request = AppConversationStartRequest(
title='Test conversation',
)
assert request.plugins is None
def test_start_request_serialization_with_plugins(self):
"""Test AppConversationStartRequest serialization includes plugins."""
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartRequest,
PluginSpec,
)
plugins = [PluginSpec(source='github:owner/repo')]
request = AppConversationStartRequest(plugins=plugins)
json_data = request.model_dump()
assert 'plugins' in json_data
assert len(json_data['plugins']) == 1
assert json_data['plugins'][0]['source'] == 'github:owner/repo'
def test_start_request_deserialization_with_plugins(self):
"""Test AppConversationStartRequest deserialization from JSON with plugins."""
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartRequest,
)
data = {
'title': 'Test',
'plugins': [
{
'source': 'github:owner/plugin',
'ref': 'main',
'parameters': {'key': 'value'},
},
],
}
request = AppConversationStartRequest.model_validate(data)
assert request.plugins is not None
assert len(request.plugins) == 1
assert request.plugins[0].source == 'github:owner/plugin'
assert request.plugins[0].ref == 'main'
assert request.plugins[0].parameters == {'key': 'value'}
def test_start_request_with_multiple_plugins(self):
"""Test AppConversationStartRequest with multiple plugins."""
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartRequest,
PluginSpec,
)
plugins = [
PluginSpec(source='github:owner/plugin1', ref='v1.0.0'),
PluginSpec(source='github:owner/plugin2', repo_path='plugins/sub'),
PluginSpec(source='/local/path'),
]
request = AppConversationStartRequest(
title='Test conversation',
plugins=plugins,
)
assert request.plugins is not None
assert len(request.plugins) == 3
assert request.plugins[0].source == 'github:owner/plugin1'
assert request.plugins[1].repo_path == 'plugins/sub'
assert request.plugins[2].source == '/local/path'
File diff suppressed because it is too large Load Diff