Compare commits

..

3 Commits

Author SHA1 Message Date
mamoodi
a77b05509e Release 1.5.0 2026-03-10 17:00:51 -04:00
Tim O'Farrell
db40eb1e94 Using the web_url where it is configured rather than the request.url (#13319)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-10 13:11:33 -06:00
Hiep Le
debbaae385 fix(backend): inherit organization llm settings for new members (#13330) 2026-03-11 01:28:46 +07:00
30 changed files with 726 additions and 1895 deletions

View File

@@ -1,71 +0,0 @@
"""Entry point for the automation executor.
Usage: python -m run_automation_executor
This runs as a Kubernetes Deployment (long-running). It polls the automation_events
inbox, matches events to automations, claims and executes runs, and monitors
conversation completion.
Environment variables:
OPENHANDS_API_URL Base URL for the V1 API (default: http://openhands-service:3000)
MAX_CONCURRENT_RUNS Max concurrent runs per executor (default: 5)
RUN_TIMEOUT_SECONDS Max time for a single run (default: 7200)
POLL_INTERVAL_SECONDS Fallback poll interval (default: 30)
HEARTBEAT_INTERVAL_SECONDS Heartbeat update interval (default: 60)
"""
import asyncio
import logging
import signal
import sys
logger = logging.getLogger('saas.automation.executor')
def _setup_logging() -> None:
"""Configure logging, deferring to enterprise logger if available."""
try:
from server.logger import setup_all_loggers
setup_all_loggers()
except ImportError:
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s %(name)s %(levelname)s %(message)s',
stream=sys.stdout,
)
def _install_signal_handlers(loop: asyncio.AbstractEventLoop) -> None:
"""Install signal handlers for graceful shutdown."""
from services.automation_executor import request_shutdown
def _handle_signal(signum: int, _frame: object) -> None:
sig_name = signal.Signals(signum).name
logger.info('Received %s, initiating graceful shutdown...', sig_name)
request_shutdown()
for sig in (signal.SIGTERM, signal.SIGINT):
signal.signal(sig, _handle_signal)
async def main() -> None:
from services.automation_executor import executor_main
await executor_main()
if __name__ == '__main__':
_setup_logging()
loop = asyncio.new_event_loop()
_install_signal_handlers(loop)
logger.info('Starting automation executor')
try:
loop.run_until_complete(main())
except KeyboardInterrupt:
logger.info('Interrupted by user')
finally:
loop.close()
logger.info('Automation executor process exiting')

View File

@@ -12,11 +12,8 @@ from server.auth.auth_error import (
)
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
from server.auth.saas_user_auth import SaasUserAuth, token_manager
from server.routes.auth import (
get_cookie_domain,
get_cookie_samesite,
set_response_cookie,
)
from server.routes.auth import set_response_cookie
from server.utils.url_utils import get_cookie_domain, get_cookie_samesite
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth.user_auth import AuthType, UserAuth, get_user_auth
@@ -93,8 +90,8 @@ class SetAuthCookieMiddleware:
if keycloak_auth_cookie:
response.delete_cookie(
key='keycloak_auth',
domain=get_cookie_domain(request),
samesite=get_cookie_samesite(request),
domain=get_cookie_domain(),
samesite=get_cookie_samesite(),
)
return response

View File

@@ -3,7 +3,7 @@ import json
import uuid
import warnings
from datetime import datetime, timezone
from typing import Annotated, Literal, Optional, cast
from typing import Annotated, Optional, cast
from urllib.parse import quote, urlencode
from uuid import UUID as parse_uuid
@@ -27,7 +27,7 @@ from server.auth.user.user_authorizer import (
depends_user_authorizer,
)
from server.config import sign_token
from server.constants import IS_FEATURE_ENV
from server.constants import IS_FEATURE_ENV, IS_LOCAL_ENV
from server.routes.event_webhook import _get_session_api_key, _get_user_id
from server.services.org_invitation_service import (
EmailMismatchError,
@@ -37,12 +37,12 @@ from server.services.org_invitation_service import (
UserAlreadyMemberError,
)
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
from server.utils.url_utils import get_cookie_domain, get_cookie_samesite, get_web_url
from sqlalchemy import select
from storage.database import a_session_maker
from storage.user import User
from storage.user_store import UserStore
from openhands.app_server.config import get_global_config
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import ProviderType, TokenResponse
@@ -77,7 +77,7 @@ def set_response_cookie(
signed_token = sign_token(cookie_data, config.jwt_secret.get_secret_value()) # type: ignore
# Set secure cookie with signed token
domain = get_cookie_domain(request)
domain = get_cookie_domain()
if domain:
response.set_cookie(
key='keycloak_auth',
@@ -85,7 +85,7 @@ def set_response_cookie(
domain=domain,
httponly=True,
secure=secure,
samesite=get_cookie_samesite(request),
samesite=get_cookie_samesite(),
)
else:
response.set_cookie(
@@ -93,30 +93,10 @@ def set_response_cookie(
value=signed_token,
httponly=True,
secure=secure,
samesite=get_cookie_samesite(request),
samesite=get_cookie_samesite(),
)
def get_cookie_domain(request: Request) -> str | None:
# for now just use the full hostname except for staging stacks.
return (
None
if not request.url.hostname
or request.url.hostname.endswith('staging.all-hands.dev')
else request.url.hostname
)
def get_cookie_samesite(request: Request) -> Literal['lax', 'strict']:
# for localhost and feature/staging stacks we set it to 'lax' as the cookie domain won't allow 'strict'
return (
'lax'
if request.url.hostname == 'localhost'
or (request.url.hostname or '').endswith('staging.all-hands.dev')
else 'strict'
)
def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None]:
"""Extract redirect URL, reCAPTCHA token, and invitation token from OAuth state.
@@ -140,19 +120,6 @@ def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None
return state, None, None
# Keep alias for backward compatibility
def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]:
"""Extract redirect URL and reCAPTCHA token from OAuth state.
Deprecated: Use _extract_oauth_state instead.
Returns:
Tuple of (redirect_url, recaptcha_token). Token may be None.
"""
redirect_url, recaptcha_token, _ = _extract_oauth_state(state)
return redirect_url, recaptcha_token
@oauth_router.get('/keycloak/callback')
async def keycloak_callback(
request: Request,
@@ -183,10 +150,7 @@ async def keycloak_callback(
detail='Missing code in request params',
)
web_url = get_global_config().web_url
if not web_url:
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
web_url = f'{scheme}://{request.url.netloc}'
web_url = get_web_url(request)
redirect_uri = web_url + request.url.path
(
@@ -313,7 +277,9 @@ async def keycloak_callback(
else:
raise
verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
verification_redirect_url = (
f'{web_url}/login?email_verification_required=true&user_id={user_id}'
)
if rate_limited:
verification_redirect_url = f'{verification_redirect_url}&rate_limited=true'
@@ -474,9 +440,7 @@ async def keycloak_callback(
# If the user hasn't accepted the TOS, redirect to the TOS page
if not has_accepted_tos:
encoded_redirect_url = quote(redirect_url, safe='')
tos_redirect_url = (
f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}'
)
tos_redirect_url = f'{web_url}/accept-tos?redirect_url={encoded_redirect_url}'
if invitation_token:
tos_redirect_url = f'{tos_redirect_url}&invitation_success=true'
response = RedirectResponse(tos_redirect_url, status_code=302)
@@ -508,10 +472,9 @@ async def keycloak_offline_callback(code: str, state: str, request: Request):
status_code=status.HTTP_400_BAD_REQUEST,
content={'error': 'Missing code in request params'},
)
scheme = 'https'
if request.url.hostname == 'localhost':
scheme = 'http'
redirect_uri = f'{scheme}://{request.url.netloc}{request.url.path}'
web_url = get_web_url(request)
redirect_uri = web_url + request.url.path
logger.debug(f'code: {code}, redirect_uri: {redirect_uri}')
(
@@ -533,15 +496,14 @@ async def keycloak_offline_callback(code: str, state: str, request: Request):
)
redirect_url, _, _ = _extract_oauth_state(state)
return RedirectResponse(
redirect_url if redirect_url else request.base_url, status_code=302
)
return RedirectResponse(redirect_url if redirect_url else web_url, status_code=302)
@oauth_router.get('/github/callback')
async def github_dummy_callback(request: Request):
"""Callback for GitHub that just forwards the user to the app base URL."""
return RedirectResponse(request.base_url, status_code=302)
web_url = get_web_url(request)
return RedirectResponse(web_url, status_code=302)
@api_router.post('/authenticate')
@@ -563,8 +525,8 @@ async def authenticate(request: Request):
if keycloak_auth_cookie:
response.delete_cookie(
key='keycloak_auth',
domain=get_cookie_domain(request),
samesite=get_cookie_samesite(request),
domain=get_cookie_domain(),
samesite=get_cookie_samesite(),
)
return response
@@ -588,7 +550,8 @@ async def accept_tos(request: Request):
# Get redirect URL from request body
body = await request.json()
redirect_url = body.get('redirect_url', str(request.base_url))
web_url = get_web_url(request)
redirect_url = body.get('redirect_url', str(web_url))
# Update user settings with TOS acceptance
accepted_tos: datetime = datetime.now(timezone.utc).replace(tzinfo=None)
@@ -618,7 +581,7 @@ async def accept_tos(request: Request):
response=response,
keycloak_access_token=access_token.get_secret_value(),
keycloak_refresh_token=refresh_token.get_secret_value(),
secure=False if request.url.hostname == 'localhost' else True,
secure=not IS_LOCAL_ENV,
accepted_tos=True,
)
return response
@@ -635,8 +598,8 @@ async def logout(request: Request):
# Always delete the cookie regardless of what happens
response.delete_cookie(
key='keycloak_auth',
domain=get_cookie_domain(request),
samesite=get_cookie_samesite(request),
domain=get_cookie_domain(),
samesite=get_cookie_samesite(),
)
# Try to properly logout from Keycloak, but don't fail if it doesn't work

View File

@@ -11,8 +11,8 @@ from integrations import stripe_service
from pydantic import BaseModel
from server.constants import STRIPE_API_KEY
from server.logger import logger
from server.utils.url_utils import get_web_url
from sqlalchemy import select
from starlette.datastructures import URL
from storage.billing_session import BillingSession
from storage.database import a_session_maker
from storage.lite_llm_manager import LiteLlmManager
@@ -151,7 +151,7 @@ async def create_customer_setup_session(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Could not find or create customer for user',
)
base_url = _get_base_url(request)
base_url = get_web_url(request)
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_info['customer_id'],
mode='setup',
@@ -170,7 +170,7 @@ async def create_checkout_session(
user_id: str = Depends(get_user_id),
) -> CreateBillingSessionResponse:
await validate_billing_enabled()
base_url = _get_base_url(request)
base_url = get_web_url(request)
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
if not customer_info:
raise HTTPException(
@@ -198,8 +198,8 @@ async def create_checkout_session(
saved_payment_method_options={
'payment_method_save': 'enabled',
},
success_url=f'{base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
cancel_url=f'{base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
success_url=f'{base_url}/api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
cancel_url=f'{base_url}/api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
)
logger.info(
'created_stripe_checkout_session',
@@ -300,7 +300,7 @@ async def success_callback(session_id: str, request: Request):
await session.commit()
return RedirectResponse(
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
f'{get_web_url(request)}/settings/billing?checkout=success', status_code=302
)
@@ -325,17 +325,9 @@ async def cancel_callback(session_id: str, request: Request):
)
billing_session.status = 'cancelled'
billing_session.updated_at = datetime.now(UTC)
session.merge(billing_session)
await session.merge(billing_session)
await session.commit()
return RedirectResponse(
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302
f'{get_web_url(request)}/settings/billing?checkout=cancel', status_code=302
)
def _get_base_url(request: Request) -> URL:
# Never send any part of the credit card process over a non secure connection
base_url = request.base_url
if base_url.hostname != 'localhost':
base_url = base_url.replace(scheme='https')
return base_url

View File

@@ -7,8 +7,10 @@ from pydantic import BaseModel, field_validator
from server.auth.constants import KEYCLOAK_CLIENT_ID
from server.auth.keycloak_manager import get_keycloak_admin
from server.auth.saas_user_auth import SaasUserAuth
from server.constants import IS_LOCAL_ENV
from server.routes.auth import set_response_cookie
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
from server.utils.url_utils import get_web_url
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
@@ -87,7 +89,7 @@ async def update_email(
response=response,
keycloak_access_token=user_auth.access_token.get_secret_value(),
keycloak_refresh_token=user_auth.refresh_token.get_secret_value(),
secure=False if request.url.hostname == 'localhost' else True,
secure=not IS_LOCAL_ENV,
accepted_tos=user_auth.accepted_tos or False,
)
@@ -156,8 +158,8 @@ async def verified_email(request: Request):
await user_auth.refresh() # refresh so access token has updated email
user_auth.email_verified = True
await UserStore.update_user_email(user_id=user_auth.user_id, email_verified=True)
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
redirect_uri = f'{scheme}://{request.url.netloc}/settings/user'
redirect_uri = f'{get_web_url(request)}/settings/user'
response = RedirectResponse(redirect_uri, status_code=302)
# need to set auth cookie to the new tokens
@@ -180,11 +182,10 @@ async def verified_email(request: Request):
async def verify_email(request: Request, user_id: str, is_auth_flow: bool = False):
keycloak_admin = get_keycloak_admin()
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
if is_auth_flow:
redirect_uri = f'{scheme}://{request.url.netloc}/login?email_verified=true'
redirect_uri = f'{get_web_url(request)}/login?email_verified=true'
else:
redirect_uri = f'{scheme}://{request.url.netloc}/api/email/verified'
redirect_uri = f'{get_web_url(request)}/api/email/verified'
logger.info(f'Redirect URI: {redirect_uri}')
await keycloak_admin.a_send_verify_email(
user_id=user_id,

View File

@@ -6,6 +6,7 @@ from typing import Optional
from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from server.utils.url_utils import get_web_url
from storage.api_key_store import ApiKeyStore
from storage.device_code_store import DeviceCodeStore
@@ -93,7 +94,7 @@ async def device_authorization(
expires_in=DEVICE_CODE_EXPIRES_IN,
)
base_url = str(http_request.base_url).rstrip('/')
base_url = get_web_url(http_request)
verification_uri = f'{base_url}/oauth/device/verify'
verification_uri_complete = (
f'{verification_uri}?user_code={device_code_entry.user_code}'

View File

@@ -365,14 +365,12 @@ class OrgInvitationService:
'Failed to set up organization access. Please try again.'
)
# Step 5: Add user to organization
from storage.org_member_store import OrgMemberStore as OMS
org_member_kwargs = OMS.get_kwargs_from_settings(settings)
# Don't override with org defaults - use invitation-specified role
org_member_kwargs.pop('llm_model', None)
org_member_kwargs.pop('llm_base_url', None)
# Step 4.5: Fetch organization to get its LLM settings
org = await OrgStore.get_org_by_id(invitation.org_id)
if not org:
raise InvitationInvalidError('Organization not found')
# Step 5: Add user to organization with inherited org LLM settings
# Get the llm_api_key as string (it's SecretStr | None in Settings)
llm_api_key = (
settings.llm_api_key.get_secret_value() if settings.llm_api_key else ''
@@ -384,6 +382,9 @@ class OrgInvitationService:
role_id=invitation.role_id,
llm_api_key=llm_api_key,
status='active',
llm_model=org.default_llm_model,
llm_base_url=org.default_llm_base_url,
max_iterations=org.default_max_iterations,
)
# Step 6: Mark invitation as accepted

View File

@@ -0,0 +1,38 @@
from typing import Literal
from fastapi import Request
from server.constants import IS_FEATURE_ENV, IS_LOCAL_ENV, IS_STAGING_ENV
from starlette.datastructures import URL
from openhands.app_server.config import get_global_config
def get_web_url(request: Request):
web_url = get_global_config().web_url
if not web_url:
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
web_url = f'{scheme}://{request.url.netloc}'
else:
web_url = web_url.rstrip('/')
return web_url
def get_cookie_domain() -> str | None:
config = get_global_config()
web_url = config.web_url
# for now just use the full hostname except for staging stacks.
return (
URL(web_url).hostname
if web_url and not (IS_FEATURE_ENV or IS_STAGING_ENV or IS_LOCAL_ENV)
else None
)
def get_cookie_samesite() -> Literal['lax', 'strict']:
# for localhost and feature/staging stacks we set it to 'lax' as the cookie domain won't allow 'strict'
web_url = get_global_config().web_url
return (
'strict'
if web_url and not (IS_FEATURE_ENV or IS_STAGING_ENV or IS_LOCAL_ENV)
else 'lax'
)

View File

@@ -1,555 +0,0 @@
"""Automation executor — processes events, claims and executes runs.
The executor is a long-running process with three phases:
1. Process inbox: match NEW events to automations, create PENDING runs
2. Claim and execute: claim PENDING runs, submit to V1 API, heartbeat
3. Stale recovery: recover RUNNING runs with expired heartbeats
"""
import asyncio
import logging
import os
import socket
from datetime import datetime, timedelta, timezone
from uuid import uuid4
from services.openhands_api_client import OpenHandsAPIClient
from sqlalchemy import or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from storage.automation import Automation, AutomationRun
from storage.automation_event import AutomationEvent
logger = logging.getLogger('saas.automation.executor')
# Environment-configurable settings
POLL_INTERVAL_SECONDS = float(os.getenv('POLL_INTERVAL_SECONDS', '30'))
HEARTBEAT_INTERVAL_SECONDS = float(os.getenv('HEARTBEAT_INTERVAL_SECONDS', '60'))
RUN_TIMEOUT_SECONDS = float(os.getenv('RUN_TIMEOUT_SECONDS', '7200'))
MAX_CONCURRENT_RUNS = int(os.getenv('MAX_CONCURRENT_RUNS', '5'))
STALE_THRESHOLD_MINUTES = 5
MAX_EVENTS_PER_BATCH = 50
MAX_RETRIES_DEFAULT = 3
# Terminal conversation statuses
TERMINAL_STATUSES = frozenset({'STOPPED', 'ERROR', 'COMPLETED', 'CANCELLED'})
# Shutdown flag — set by signal handlers
_shutdown_event: asyncio.Event | None = None
# Background task tracking for graceful shutdown
_pending_tasks: set[asyncio.Task] = set()
def utc_now() -> datetime:
return datetime.now(timezone.utc)
def get_shutdown_event() -> asyncio.Event:
global _shutdown_event
if _shutdown_event is None:
_shutdown_event = asyncio.Event()
return _shutdown_event
def should_continue() -> bool:
return not get_shutdown_event().is_set()
def request_shutdown() -> None:
get_shutdown_event().set()
# ---------------------------------------------------------------------------
# Phase 1: Process inbox (event matching)
# ---------------------------------------------------------------------------
async def find_matching_automations(
session: AsyncSession, event: AutomationEvent
) -> list[Automation]:
"""Find automations that match the given event.
Phase 1 supports cron and manual triggers only — both carry
``automation_id`` in the event payload.
"""
source_type = event.source_type
payload = event.payload
if payload is None:
logger.error('Event %s has None payload — possible data corruption', event.id)
return []
if source_type in ('cron', 'manual'):
automation_id = payload.get('automation_id')
if not automation_id:
logger.warning(
'Event %s (source=%s) missing automation_id in payload',
event.id,
source_type,
)
return []
result = await session.execute(
select(Automation).where(
Automation.id == automation_id,
Automation.enabled.is_(True),
)
)
automation = result.scalar_one_or_none()
return [automation] if automation else []
logger.debug('Unhandled event source_type=%s for event %s', source_type, event.id)
return []
async def process_new_events(session: AsyncSession) -> int:
"""Claim NEW events from inbox, match to automations, create runs.
Returns the number of events processed.
"""
result = await session.execute(
select(AutomationEvent)
.where(AutomationEvent.status == 'NEW')
.order_by(AutomationEvent.created_at)
.limit(MAX_EVENTS_PER_BATCH)
.with_for_update(skip_locked=True)
)
events = list(result.scalars())
processed = 0
for event in events:
try:
automations = await find_matching_automations(session, event)
if not automations:
event.status = 'NO_MATCH'
event.processed_at = utc_now()
else:
for automation in automations:
run = AutomationRun(
id=uuid4().hex,
automation_id=automation.id,
event_id=event.id,
status='PENDING',
event_payload=event.payload,
)
session.add(run)
event.status = 'PROCESSED'
event.processed_at = utc_now()
processed += 1
except Exception as e:
logger.exception('Error processing event %s', event.id)
event.status = 'ERROR'
event.error_detail = f'Failed during event matching: {type(e).__name__}: {e}'
event.processed_at = utc_now()
if processed:
await session.commit()
logger.info('Processed %d events', processed)
return processed
# ---------------------------------------------------------------------------
# Phase 2: Claim and execute runs
# ---------------------------------------------------------------------------
async def resolve_user_api_key(session: AsyncSession, user_id: str) -> str | None:
"""Look up a user's API key from the api_keys table.
Returns the first active key found, or None.
"""
from storage.api_key import ApiKey
result = await session.execute(
select(ApiKey.key).where(ApiKey.user_id == user_id).limit(1)
)
row = result.scalar_one_or_none()
return row
async def download_automation_file(file_store_key: str) -> bytes:
"""Download the automation .py file from object storage."""
try:
from openhands.server.shared import file_store
except ImportError as exc:
raise RuntimeError(
'file_store is not available — ensure the enterprise server '
'has been initialised before calling download_automation_file'
) from exc
content = file_store.read(file_store_key)
if isinstance(content, str):
return content.encode('utf-8')
return content
def is_terminal(conversation: dict) -> bool:
"""Check if a conversation has reached a terminal status."""
status = (conversation.get('status') or '').upper()
return status in TERMINAL_STATUSES
async def _prepare_run(
run: AutomationRun,
automation: Automation,
session_factory: object,
) -> tuple[str, bytes]:
"""Resolve the user's API key and download the automation file.
Returns:
(api_key, automation_file) tuple ready for submission.
Raises:
ValueError: If no API key is found.
RuntimeError: If file_store is unavailable.
"""
async with session_factory() as key_session:
api_key = await resolve_user_api_key(key_session, automation.user_id)
if not api_key:
raise ValueError(f'No API key found for user {automation.user_id}')
automation_file = await download_automation_file(automation.file_store_key)
return api_key, automation_file
async def _monitor_conversation(
run: AutomationRun,
conversation_id: str,
api_client: OpenHandsAPIClient,
api_key: str,
session_factory: object,
) -> bool:
"""Monitor a conversation until completion or timeout.
Returns True if completed successfully, False if shutdown requested.
Raises:
TimeoutError: If the run exceeds RUN_TIMEOUT_SECONDS.
"""
start_time = utc_now()
while should_continue():
elapsed = (utc_now() - start_time).total_seconds()
if elapsed > RUN_TIMEOUT_SECONDS:
raise TimeoutError(f'Run exceeded {RUN_TIMEOUT_SECONDS}s timeout')
await asyncio.sleep(HEARTBEAT_INTERVAL_SECONDS)
# Update heartbeat
async with session_factory() as session:
run_obj = await session.get(AutomationRun, run.id)
if run_obj:
run_obj.heartbeat_at = utc_now()
await session.commit()
# Check conversation status
conversation = (
await api_client.get_conversation(api_key, conversation_id) or {}
)
if is_terminal(conversation):
return True
return False # shutdown requested
async def _submit_and_monitor(
run: AutomationRun,
api_key: str,
automation_file: bytes,
automation: Automation,
api_client: OpenHandsAPIClient,
session_factory: object,
) -> None:
"""Submit the automation to the V1 API and monitor until completion.
Updates the run's conversation_id, sends heartbeats, and marks the
final status when the conversation reaches a terminal state.
"""
conversation = await api_client.start_conversation(
api_key=api_key,
automation_file=automation_file,
title=f'Automation: {automation.name}',
event_payload=run.event_payload,
)
conversation_id = conversation.get('app_conversation_id') or conversation.get(
'conversation_id'
)
# Persist conversation ID
async with session_factory() as update_session:
run_obj = await update_session.get(AutomationRun, run.id)
if run_obj:
run_obj.conversation_id = conversation_id
await update_session.commit()
# Monitor with heartbeats
completed = await _monitor_conversation(
run, conversation_id, api_client, api_key, session_factory
)
# Update final status
async with session_factory() as final_session:
run_obj = await final_session.get(AutomationRun, run.id)
if run_obj:
if not completed:
# Leave as RUNNING — stale recovery will handle it if needed.
# The conversation may still be running on the API side.
logger.info(
'Run %s left as RUNNING due to executor shutdown', run.id
)
else:
run_obj.status = 'COMPLETED'
run_obj.completed_at = utc_now()
logger.info('Run %s completed successfully', run.id)
await final_session.commit()
async def execute_run(
run: AutomationRun,
automation: Automation,
api_client: OpenHandsAPIClient,
session_factory: object,
) -> None:
"""Execute a single automation run end-to-end.
Orchestrates preparation (API key + file download) and submission/monitoring.
On failure, marks the run for retry or dead-letter.
"""
try:
api_key, automation_file = await _prepare_run(
run, automation, session_factory
)
await _submit_and_monitor(
run, api_key, automation_file, automation, api_client, session_factory
)
except Exception as e:
logger.exception('Run %s failed: %s', run.id, e)
await _mark_run_failed(run, str(e), session_factory)
async def _mark_run_failed(
run: AutomationRun, error: str, session_factory: object
) -> None:
"""Mark a run as FAILED or return to PENDING for retry."""
async with session_factory() as session:
run_obj = await session.get(AutomationRun, run.id)
if not run_obj:
return
run_obj.retry_count = (run_obj.retry_count or 0) + 1
run_obj.error_detail = error
if run_obj.retry_count >= (run_obj.max_retries or MAX_RETRIES_DEFAULT):
run_obj.status = 'DEAD_LETTER'
run_obj.completed_at = utc_now()
logger.error(
'Run %s moved to DEAD_LETTER after %d retries',
run.id,
run_obj.retry_count,
)
else:
run_obj.status = 'PENDING'
run_obj.claimed_by = None
backoff_seconds = 30 * (2 ** (run_obj.retry_count - 1))
run_obj.next_retry_at = utc_now() + timedelta(seconds=backoff_seconds)
logger.warning(
'Run %s returned to PENDING, retry %d/%d in %ds',
run.id,
run_obj.retry_count,
run_obj.max_retries or MAX_RETRIES_DEFAULT,
backoff_seconds,
)
await session.commit()
async def claim_and_execute_runs(
session: AsyncSession,
executor_id: str,
api_client: OpenHandsAPIClient,
session_factory: object,
) -> bool:
"""Claim a PENDING run and start executing it.
Returns True if a run was claimed, False otherwise.
"""
result = await session.execute(
select(AutomationRun)
.where(
AutomationRun.status == 'PENDING',
or_(
AutomationRun.next_retry_at.is_(None),
AutomationRun.next_retry_at <= utc_now(),
),
)
.order_by(AutomationRun.created_at)
.limit(1)
.with_for_update(skip_locked=True)
)
run = result.scalar_one_or_none()
if not run:
return False
# Claim the run
run.status = 'RUNNING'
run.claimed_by = executor_id
run.claimed_at = utc_now()
run.heartbeat_at = utc_now()
run.started_at = utc_now()
await session.commit()
# Load automation for the run
auto_result = await session.execute(
select(Automation).where(Automation.id == run.automation_id)
)
automation = auto_result.scalar_one_or_none()
if not automation:
logger.error('Automation %s not found for run %s', run.automation_id, run.id)
await _mark_run_failed(
run, f'Automation {run.automation_id} not found', session_factory
)
return True
# Execute in background (long-running) with task tracking
task = asyncio.create_task(
execute_run(run, automation, api_client, session_factory),
name=f'execute-run-{run.id}',
)
_pending_tasks.add(task)
task.add_done_callback(_pending_tasks.discard)
logger.info(
'Claimed run %s (automation=%s) by executor %s',
run.id,
run.automation_id,
executor_id,
)
return True
# ---------------------------------------------------------------------------
# Phase 3: Stale run recovery
# ---------------------------------------------------------------------------
async def recover_stale_runs(session: AsyncSession) -> int:
"""Mark RUNNING runs with expired heartbeats as PENDING for retry.
Returns the number of recovered runs.
"""
stale_threshold = utc_now() - timedelta(minutes=STALE_THRESHOLD_MINUTES)
timeout_threshold = utc_now() - timedelta(seconds=RUN_TIMEOUT_SECONDS)
# Recover stale runs (heartbeat expired)
result = await session.execute(
update(AutomationRun)
.where(
AutomationRun.status == 'RUNNING',
AutomationRun.heartbeat_at < stale_threshold,
AutomationRun.heartbeat_at >= timeout_threshold,
)
.values(
status='PENDING',
claimed_by=None,
retry_count=AutomationRun.retry_count + 1,
next_retry_at=utc_now() + timedelta(seconds=30),
)
.returning(AutomationRun.id)
)
recovered_rows = result.fetchall()
# Mark truly timed-out runs as DEAD_LETTER
timeout_result = await session.execute(
update(AutomationRun)
.where(
AutomationRun.status == 'RUNNING',
AutomationRun.heartbeat_at < timeout_threshold,
)
.values(
status='DEAD_LETTER',
error_detail='Run exceeded timeout',
completed_at=utc_now(),
)
.returning(AutomationRun.id)
)
timed_out_rows = timeout_result.fetchall()
await session.commit()
recovered_count = len(recovered_rows)
timed_out_count = len(timed_out_rows)
if recovered_count:
logger.warning('Recovered %d stale automation runs', recovered_count)
if timed_out_count:
logger.warning(
'Marked %d automation runs as DEAD_LETTER (timeout)', timed_out_count
)
return recovered_count + timed_out_count
# ---------------------------------------------------------------------------
# Main executor loop
# ---------------------------------------------------------------------------
async def executor_main(session_factory: object | None = None) -> None:
"""Main executor loop.
Args:
session_factory: Async context manager that yields AsyncSession instances.
If None, uses the default ``a_session_maker`` from database module.
"""
if session_factory is None:
from storage.database import a_session_maker
session_factory = a_session_maker
executor_id = f'executor-{socket.gethostname()}-{os.getpid()}'
api_url = os.getenv('OPENHANDS_API_URL', 'http://openhands-service:3000')
api_client = OpenHandsAPIClient(base_url=api_url)
logger.info(
'Automation executor %s starting (api_url=%s, poll=%ss, heartbeat=%ss)',
executor_id,
api_url,
POLL_INTERVAL_SECONDS,
HEARTBEAT_INTERVAL_SECONDS,
)
try:
while should_continue():
try:
async with session_factory() as session:
await process_new_events(session)
async with session_factory() as session:
await claim_and_execute_runs(
session, executor_id, api_client, session_factory
)
async with session_factory() as session:
await recover_stale_runs(session)
except Exception:
logger.exception('Error in executor main loop iteration')
# Wait for next poll interval (or early wakeup on shutdown)
try:
await asyncio.wait_for(
get_shutdown_event().wait(),
timeout=POLL_INTERVAL_SECONDS,
)
except asyncio.TimeoutError:
pass # Normal — poll interval elapsed
finally:
if _pending_tasks:
logger.info(
'Waiting for %d running tasks to complete...', len(_pending_tasks)
)
await asyncio.gather(*_pending_tasks, return_exceptions=True)
await api_client.close()
logger.info('Automation executor %s shut down', executor_id)

View File

@@ -1,93 +0,0 @@
"""HTTP client for the main OpenHands V1 API (internal cluster calls).
Used by the automation executor to create and monitor conversations
in the main OpenHands server.
"""
import base64
import logging
import httpx
logger = logging.getLogger('saas.automation.api_client')
def _raise_with_body(resp: httpx.Response) -> None:
"""Call raise_for_status, enriching the error with the response body."""
try:
resp.raise_for_status()
except httpx.HTTPStatusError as e:
error_body = resp.text[:500] if resp.text else 'no response body'
raise httpx.HTTPStatusError(
f'{e.args[0]} — Response: {error_body}',
request=e.request,
response=e.response,
) from e
class OpenHandsAPIClient:
"""Async HTTP client for the OpenHands V1 API."""
def __init__(self, base_url: str = 'http://openhands-service:3000'):
self.base_url = base_url.rstrip('/')
self.client = httpx.AsyncClient(base_url=self.base_url, timeout=30.0)
async def start_conversation(
self,
api_key: str,
automation_file: bytes,
title: str,
event_payload: dict | None = None,
) -> dict:
"""Submit an SDK script for sandboxed execution via V1 API.
Args:
api_key: User's API key for authentication.
automation_file: Raw bytes of the .py automation script.
title: Display title for the conversation.
event_payload: Optional trigger event data (injected as env var).
Returns:
Parsed JSON response containing conversation details.
Raises:
httpx.HTTPStatusError: If the API returns a non-2xx status.
"""
resp = await self.client.post(
'/api/v1/app-conversations',
json={
'automation_file': base64.b64encode(automation_file).decode(),
'trigger': 'automation',
'title': title,
'event_payload': event_payload,
},
headers={'Authorization': f'Bearer {api_key}'},
)
_raise_with_body(resp)
return resp.json()
async def get_conversation(self, api_key: str, conversation_id: str) -> dict | None:
"""Get conversation status.
Args:
api_key: User's API key for authentication.
conversation_id: The conversation ID to look up.
Returns:
Conversation data dict, or None if not found.
Raises:
httpx.HTTPStatusError: If the API returns a non-2xx status.
"""
resp = await self.client.get(
'/api/v1/app-conversations',
params={'ids': [conversation_id]},
headers={'Authorization': f'Bearer {api_key}'},
)
_raise_with_body(resp)
conversations = resp.json()
return conversations[0] if conversations else None
async def close(self) -> None:
"""Close the underlying HTTP client."""
await self.client.aclose()

View File

@@ -1,77 +0,0 @@
"""SQLAlchemy models for automations and automation runs.
Stub for Task 1 (Data Foundation). These models will be replaced when Task 1
is merged into automations-phase1.
"""
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
Text,
text,
)
from sqlalchemy.orm import relationship
from sqlalchemy.types import JSON
from storage.base import Base
class Automation(Base):
__tablename__ = 'automations'
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False, index=True)
org_id = Column(String, nullable=True, index=True)
name = Column(String, nullable=False)
enabled = Column(Boolean, nullable=False, server_default=text('true'))
config = Column(JSON, nullable=False)
trigger_type = Column(String, nullable=False)
file_store_key = Column(String, nullable=False)
last_triggered_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column(
DateTime(timezone=True),
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
)
updated_at = Column(
DateTime(timezone=True),
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
)
runs = relationship('AutomationRun', back_populates='automation')
class AutomationRun(Base):
__tablename__ = 'automation_runs'
id = Column(String, primary_key=True)
automation_id = Column(
String, ForeignKey('automations.id', ondelete='CASCADE'), nullable=False
)
event_id = Column(Integer, ForeignKey('automation_events.id'), nullable=True)
conversation_id = Column(String, nullable=True)
status = Column(String, nullable=False, server_default=text("'PENDING'"))
claimed_by = Column(String, nullable=True)
claimed_at = Column(DateTime(timezone=True), nullable=True)
heartbeat_at = Column(DateTime(timezone=True), nullable=True)
retry_count = Column(Integer, nullable=False, server_default=text('0'))
max_retries = Column(Integer, nullable=False, server_default=text('3'))
next_retry_at = Column(DateTime(timezone=True), nullable=True)
event_payload = Column(JSON, nullable=True)
error_detail = Column(Text, nullable=True)
started_at = Column(DateTime(timezone=True), nullable=True)
completed_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column(
DateTime(timezone=True),
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
)
automation = relationship('Automation', back_populates='runs')

View File

@@ -1,27 +0,0 @@
"""SQLAlchemy model for automation events (the inbox).
Stub for Task 1 (Data Foundation). This model will be replaced when Task 1
is merged into automations-phase1.
"""
from sqlalchemy import Column, DateTime, Integer, String, Text, text
from sqlalchemy.types import JSON
from storage.base import Base
class AutomationEvent(Base):
__tablename__ = 'automation_events'
id = Column(Integer, primary_key=True, autoincrement=True)
source_type = Column(String, nullable=False)
payload = Column(JSON, nullable=False)
metadata_ = Column('metadata', JSON, nullable=True)
dedup_key = Column(String, nullable=False, unique=True)
status = Column(String, nullable=False, server_default=text("'NEW'"))
error_detail = Column(Text, nullable=True)
created_at = Column(
DateTime(timezone=True),
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
)
processed_at = Column(DateTime(timezone=True), nullable=True)

View File

@@ -28,6 +28,9 @@ class OrgMemberStore:
role_id: int,
llm_api_key: str,
status: Optional[str] = None,
llm_model: Optional[str] = None,
llm_base_url: Optional[str] = None,
max_iterations: Optional[int] = None,
) -> OrgMember:
"""Add a user to an organization with a specific role."""
async with a_session_maker() as session:
@@ -37,6 +40,9 @@ class OrgMemberStore:
role_id=role_id,
llm_api_key=llm_api_key,
status=status,
llm_model=llm_model,
llm_base_url=llm_base_url,
max_iterations=max_iterations,
)
session.add(org_member)
await session.commit()

View File

@@ -187,6 +187,18 @@ class SaasSettingsStore(SettingsStore):
if hasattr(model, key):
setattr(model, key, value)
# Map Settings fields to Org fields with 'default_' prefix
# The generic loop above doesn't update these because Org uses
# 'default_llm_model' not 'llm_model', etc.
# Use exclude_unset to only update explicitly-set fields (allows clearing with null)
settings_data = item.model_dump(exclude_unset=True)
if 'llm_model' in settings_data:
org.default_llm_model = settings_data['llm_model']
if 'llm_base_url' in settings_data:
org.default_llm_base_url = settings_data['llm_base_url']
if 'max_iterations' in settings_data:
org.default_max_iterations = settings_data['max_iterations']
# Propagate LLM settings to all org members
# This ensures all members see the same LLM configuration when an admin saves
# Note: Concurrent saves by multiple admins will result in last-write-wins.

View File

@@ -1,68 +0,0 @@
"""Shared fixtures for services tests.
Note: We pre-load ``storage`` as a namespace package to avoid the heavy
``storage/__init__.py`` that imports the entire enterprise model graph.
This must happen *before* any ``from storage.…`` import.
"""
import contextlib
import sys
import types
# Prevent storage/__init__.py from loading the full model graph.
# We only need the lightweight automation models for these tests.
if 'storage' not in sys.modules:
import pathlib
_storage_dir = str(pathlib.Path(__file__).resolve().parents[3] / 'storage')
_mod = types.ModuleType('storage')
_mod.__path__ = [_storage_dir]
sys.modules['storage'] = _mod
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.automation import Automation, AutomationRun # noqa: F401
from storage.automation_event import AutomationEvent # noqa: F401
from storage.base import Base
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session_factory(async_engine):
"""Create an async session factory that yields context-managed sessions."""
factory = async_sessionmaker(
bind=async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
@contextlib.asynccontextmanager
async def _session_ctx():
async with factory() as session:
yield session
return _session_ctx
@pytest.fixture
async def async_session(async_session_factory):
"""Create a single async session for testing."""
async with async_session_factory() as session:
yield session

View File

@@ -1,624 +0,0 @@
"""Tests for the automation executor.
Uses real SQLite database operations for event processing, run claiming,
and stale run recovery. HTTP calls to the V1 API are mocked.
"""
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from services.automation_executor import (
_mark_run_failed,
claim_and_execute_runs,
find_matching_automations,
is_terminal,
process_new_events,
recover_stale_runs,
utc_now,
)
from sqlalchemy import select
from storage.automation import Automation, AutomationRun
from storage.automation_event import AutomationEvent
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_automation(
automation_id: str = 'auto-1',
user_id: str = 'user-1',
enabled: bool = True,
trigger_type: str = 'cron',
name: str = 'Test Automation',
) -> Automation:
return Automation(
id=automation_id,
user_id=user_id,
org_id='org-1',
name=name,
enabled=enabled,
config={'triggers': {'cron': {'schedule': '0 9 * * 5'}}},
trigger_type=trigger_type,
file_store_key=f'automations/{automation_id}/script.py',
)
def make_event(
source_type: str = 'cron',
payload: dict | None = None,
status: str = 'NEW',
dedup_key: str | None = None,
) -> AutomationEvent:
return AutomationEvent(
source_type=source_type,
payload=payload or {'automation_id': 'auto-1'},
dedup_key=dedup_key or f'dedup-{uuid4().hex[:8]}',
status=status,
created_at=utc_now(),
)
def make_run(
run_id: str | None = None,
automation_id: str = 'auto-1',
status: str = 'PENDING',
claimed_by: str | None = None,
heartbeat_at: datetime | None = None,
retry_count: int = 0,
max_retries: int = 3,
next_retry_at: datetime | None = None,
) -> AutomationRun:
return AutomationRun(
id=run_id or uuid4().hex,
automation_id=automation_id,
status=status,
claimed_by=claimed_by,
heartbeat_at=heartbeat_at,
retry_count=retry_count,
max_retries=max_retries,
next_retry_at=next_retry_at,
event_payload={'automation_id': automation_id},
created_at=utc_now(),
)
# ---------------------------------------------------------------------------
# find_matching_automations
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_find_matching_automations_cron_event(async_session):
"""Cron events match by automation_id in payload."""
automation = make_automation()
async_session.add(automation)
await async_session.commit()
event = make_event(source_type='cron', payload={'automation_id': 'auto-1'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 1
assert result[0].id == 'auto-1'
@pytest.mark.asyncio
async def test_find_matching_automations_manual_event(async_session):
"""Manual events also match by automation_id in payload."""
automation = make_automation()
async_session.add(automation)
await async_session.commit()
event = make_event(source_type='manual', payload={'automation_id': 'auto-1'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 1
assert result[0].id == 'auto-1'
@pytest.mark.asyncio
async def test_find_matching_automations_disabled_automation(async_session):
"""Disabled automations are not matched."""
automation = make_automation(enabled=False)
async_session.add(automation)
await async_session.commit()
event = make_event(payload={'automation_id': 'auto-1'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 0
@pytest.mark.asyncio
async def test_find_matching_automations_missing_automation_id(async_session):
"""Events without automation_id in payload return empty list."""
event = make_event(payload={'something_else': 'value'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 0
@pytest.mark.asyncio
async def test_find_matching_automations_nonexistent_automation(async_session):
"""Events referencing a non-existent automation return empty list."""
event = make_event(payload={'automation_id': 'nonexistent'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 0
@pytest.mark.asyncio
async def test_find_matching_automations_unknown_source_type(async_session):
"""Unknown source types return empty list."""
event = make_event(source_type='unknown', payload={'automation_id': 'auto-1'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 0
# ---------------------------------------------------------------------------
# process_new_events
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_process_new_events_creates_runs(async_session):
"""Processing NEW events creates PENDING runs and marks events PROCESSED."""
automation = make_automation()
event = make_event(payload={'automation_id': 'auto-1'})
async_session.add_all([automation, event])
await async_session.commit()
count = await process_new_events(async_session)
assert count == 1
# Event should be PROCESSED
await async_session.refresh(event)
assert event.status == 'PROCESSED'
assert event.processed_at is not None
# A run should have been created
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 1
assert runs[0].automation_id == 'auto-1'
assert runs[0].status == 'PENDING'
assert runs[0].event_payload == {'automation_id': 'auto-1'}
@pytest.mark.asyncio
async def test_process_new_events_no_match(async_session):
"""Events with no matching automation are marked NO_MATCH."""
event = make_event(payload={'automation_id': 'nonexistent'})
async_session.add(event)
await async_session.commit()
count = await process_new_events(async_session)
assert count == 1
await async_session.refresh(event)
assert event.status == 'NO_MATCH'
assert event.processed_at is not None
# No runs created
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 0
@pytest.mark.asyncio
async def test_process_new_events_skips_processed(async_session):
"""Already processed events are not re-processed."""
event = make_event(status='PROCESSED')
async_session.add(event)
await async_session.commit()
count = await process_new_events(async_session)
assert count == 0
@pytest.mark.asyncio
async def test_process_new_events_multiple_events(async_session):
"""Multiple NEW events are processed in one batch."""
auto1 = make_automation(automation_id='auto-1')
auto2 = make_automation(automation_id='auto-2', name='Auto 2')
event1 = make_event(payload={'automation_id': 'auto-1'}, dedup_key='dedup-1')
event2 = make_event(payload={'automation_id': 'auto-2'}, dedup_key='dedup-2')
event3 = make_event(payload={'automation_id': 'nonexistent'}, dedup_key='dedup-3')
async_session.add_all([auto1, auto2, event1, event2, event3])
await async_session.commit()
count = await process_new_events(async_session)
assert count == 3
# Two runs created (for auto-1 and auto-2), none for nonexistent
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 2
await async_session.refresh(event1)
await async_session.refresh(event2)
await async_session.refresh(event3)
assert event1.status == 'PROCESSED'
assert event2.status == 'PROCESSED'
assert event3.status == 'NO_MATCH'
# ---------------------------------------------------------------------------
# claim_and_execute_runs
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_claim_and_execute_runs_claims_pending(
async_session, async_session_factory
):
"""Claims a PENDING run and transitions to RUNNING."""
automation = make_automation()
run = make_run(run_id='run-1')
async_session.add_all([automation, run])
await async_session.commit()
api_client = AsyncMock()
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is True
await async_session.refresh(run)
assert run.status == 'RUNNING'
assert run.claimed_by == 'executor-test-1'
assert run.claimed_at is not None
assert run.heartbeat_at is not None
assert run.started_at is not None
@pytest.mark.asyncio
async def test_claim_and_execute_runs_no_pending(async_session, async_session_factory):
"""Returns False when no PENDING runs exist."""
api_client = AsyncMock()
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is False
@pytest.mark.asyncio
async def test_claim_and_execute_runs_respects_next_retry_at(
async_session, async_session_factory
):
"""Runs with future next_retry_at are not claimed."""
automation = make_automation()
run = make_run(
run_id='run-retry',
next_retry_at=utc_now() + timedelta(hours=1),
)
async_session.add_all([automation, run])
await async_session.commit()
api_client = AsyncMock()
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is False
@pytest.mark.asyncio
async def test_claim_and_execute_runs_past_retry_at(
async_session, async_session_factory
):
"""Runs with past next_retry_at are claimable."""
automation = make_automation()
run = make_run(
run_id='run-retry-past',
next_retry_at=utc_now() - timedelta(minutes=5),
)
async_session.add_all([automation, run])
await async_session.commit()
api_client = AsyncMock()
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is True
@pytest.mark.asyncio
async def test_claim_skips_running_runs(async_session, async_session_factory):
"""RUNNING runs are not claimed."""
automation = make_automation()
run = make_run(run_id='run-running', status='RUNNING', claimed_by='other-executor')
async_session.add_all([automation, run])
await async_session.commit()
api_client = AsyncMock()
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is False
# ---------------------------------------------------------------------------
# recover_stale_runs
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_recover_stale_runs_recovers_stale(async_session):
"""RUNNING runs with expired heartbeats are recovered to PENDING."""
automation = make_automation()
stale_run = make_run(
run_id='stale-1',
status='RUNNING',
claimed_by='crashed-executor',
heartbeat_at=utc_now() - timedelta(minutes=10),
retry_count=0,
)
async_session.add_all([automation, stale_run])
await async_session.commit()
count = await recover_stale_runs(async_session)
assert count >= 1
await async_session.refresh(stale_run)
assert stale_run.status == 'PENDING'
assert stale_run.claimed_by is None
assert stale_run.retry_count == 1
assert stale_run.next_retry_at is not None
@pytest.mark.asyncio
async def test_recover_stale_runs_ignores_fresh(async_session):
"""RUNNING runs with recent heartbeats are not recovered."""
automation = make_automation()
fresh_run = make_run(
run_id='fresh-1',
status='RUNNING',
claimed_by='active-executor',
heartbeat_at=utc_now() - timedelta(seconds=30),
)
async_session.add_all([automation, fresh_run])
await async_session.commit()
count = await recover_stale_runs(async_session)
assert count == 0
await async_session.refresh(fresh_run)
assert fresh_run.status == 'RUNNING'
assert fresh_run.claimed_by == 'active-executor'
@pytest.mark.asyncio
async def test_recover_stale_runs_ignores_pending(async_session):
"""PENDING runs are not affected by recovery."""
automation = make_automation()
pending_run = make_run(run_id='pending-1', status='PENDING')
async_session.add_all([automation, pending_run])
await async_session.commit()
count = await recover_stale_runs(async_session)
assert count == 0
await async_session.refresh(pending_run)
assert pending_run.status == 'PENDING'
@pytest.mark.asyncio
async def test_recover_stale_runs_increments_retry_count(async_session):
"""Recovery increments the retry_count."""
automation = make_automation()
stale_run = make_run(
run_id='stale-retry',
status='RUNNING',
claimed_by='old-executor',
heartbeat_at=utc_now() - timedelta(minutes=10),
retry_count=2,
)
async_session.add_all([automation, stale_run])
await async_session.commit()
await recover_stale_runs(async_session)
await async_session.refresh(stale_run)
assert stale_run.retry_count == 3
# ---------------------------------------------------------------------------
# _mark_run_failed (error handling)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mark_run_failed_retries(async_session_factory):
"""Failed runs with retries left return to PENDING."""
async with async_session_factory() as session:
automation = make_automation()
run = make_run(run_id='fail-retry', retry_count=0, max_retries=3)
session.add_all([automation, run])
await session.commit()
async with async_session_factory() as session:
run_obj = await session.get(AutomationRun, 'fail-retry')
await _mark_run_failed(run_obj, 'API error', async_session_factory)
async with async_session_factory() as session:
run_obj = await session.get(AutomationRun, 'fail-retry')
assert run_obj.status == 'PENDING'
assert run_obj.retry_count == 1
assert run_obj.error_detail == 'API error'
assert run_obj.next_retry_at is not None
assert run_obj.claimed_by is None
@pytest.mark.asyncio
async def test_mark_run_failed_dead_letter(async_session_factory):
"""Failed runs that exceed max_retries go to DEAD_LETTER."""
async with async_session_factory() as session:
automation = make_automation()
run = make_run(run_id='fail-dead', retry_count=2, max_retries=3)
session.add_all([automation, run])
await session.commit()
async with async_session_factory() as session:
run_obj = await session.get(AutomationRun, 'fail-dead')
await _mark_run_failed(run_obj, 'Final failure', async_session_factory)
async with async_session_factory() as session:
run_obj = await session.get(AutomationRun, 'fail-dead')
assert run_obj.status == 'DEAD_LETTER'
assert run_obj.retry_count == 3
assert run_obj.error_detail == 'Final failure'
assert run_obj.completed_at is not None
# ---------------------------------------------------------------------------
# is_terminal
# ---------------------------------------------------------------------------
def test_is_terminal_stopped():
assert is_terminal({'status': 'STOPPED'}) is True
def test_is_terminal_error():
assert is_terminal({'status': 'ERROR'}) is True
def test_is_terminal_completed():
assert is_terminal({'status': 'COMPLETED'}) is True
def test_is_terminal_cancelled():
assert is_terminal({'status': 'CANCELLED'}) is True
def test_is_terminal_running():
assert is_terminal({'status': 'RUNNING'}) is False
def test_is_terminal_empty():
assert is_terminal({}) is False
def test_is_terminal_case_insensitive():
assert is_terminal({'status': 'stopped'}) is True
assert is_terminal({'status': 'Completed'}) is True
# ---------------------------------------------------------------------------
# find_matching_automations — None payload
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_find_matching_automations_none_payload(async_session):
"""Events with None payload return empty list (data corruption guard)."""
event = make_event(source_type='cron')
event.payload = None
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert result == []
# ---------------------------------------------------------------------------
# Integration: event → run creation → claim
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_integration_event_to_run_to_claim(
async_session_factory,
):
"""Full flow: create event + automation → process_new_events → claim_and_execute_runs.
Uses a real SQLite database; only the external API client is mocked.
"""
# 1. Seed an automation and a NEW event
async with async_session_factory() as session:
automation = make_automation(automation_id='integ-auto')
event = make_event(
source_type='cron',
payload={'automation_id': 'integ-auto'},
dedup_key='integ-dedup',
)
session.add_all([automation, event])
await session.commit()
event_id = event.id
# 2. Process inbox — should match and create a PENDING run
async with async_session_factory() as session:
processed = await process_new_events(session)
assert processed == 1
# Verify event is PROCESSED and run was created
async with async_session_factory() as session:
evt = await session.get(AutomationEvent, event_id)
assert evt.status == 'PROCESSED'
runs = (await session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 1
run = runs[0]
assert run.automation_id == 'integ-auto'
assert run.status == 'PENDING'
assert run.event_payload == {'automation_id': 'integ-auto'}
# 3. Claim the run — mock execute_run to avoid real API calls
api_client = AsyncMock()
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
async with async_session_factory() as session:
claimed = await claim_and_execute_runs(
session, 'executor-integ', api_client, async_session_factory
)
assert claimed is True
# 4. Verify the run moved to RUNNING with correct executor
async with async_session_factory() as session:
runs = (await session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 1
run = runs[0]
assert run.status == 'RUNNING'
assert run.claimed_by == 'executor-integ'
assert run.started_at is not None
assert run.heartbeat_at is not None

View File

@@ -1,185 +0,0 @@
"""Tests for OpenHandsAPIClient with mocked HTTP responses."""
import base64
import httpx
import pytest
from services.openhands_api_client import OpenHandsAPIClient
@pytest.fixture
def api_client():
client = OpenHandsAPIClient(base_url='http://test-server:3000')
yield client
# close handled in tests that need it
# ---------------------------------------------------------------------------
# start_conversation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_start_conversation_sends_correct_request(api_client, respx_mock):
"""start_conversation sends properly formatted POST with auth header."""
automation_file = b'print("hello")'
expected_b64 = base64.b64encode(automation_file).decode()
route = respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(
200,
json={
'app_conversation_id': 'conv-123',
'status': 'RUNNING',
},
)
)
result = await api_client.start_conversation(
api_key='sk-oh-test123',
automation_file=automation_file,
title='Test Automation',
event_payload={'automation_id': 'auto-1'},
)
assert route.called
request = route.calls[0].request
assert request.headers['Authorization'] == 'Bearer sk-oh-test123'
import json
body = json.loads(request.content)
assert body['automation_file'] == expected_b64
assert body['trigger'] == 'automation'
assert body['title'] == 'Test Automation'
assert body['event_payload'] == {'automation_id': 'auto-1'}
assert result == {'app_conversation_id': 'conv-123', 'status': 'RUNNING'}
@pytest.mark.asyncio
async def test_start_conversation_without_event_payload(api_client, respx_mock):
"""start_conversation works with event_payload=None."""
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(200, json={'app_conversation_id': 'conv-456'})
)
result = await api_client.start_conversation(
api_key='sk-oh-test',
automation_file=b'code',
title='Test',
event_payload=None,
)
assert result['app_conversation_id'] == 'conv-456'
@pytest.mark.asyncio
async def test_start_conversation_http_error(api_client, respx_mock):
"""start_conversation raises on HTTP errors."""
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(500, json={'error': 'Internal Server Error'})
)
with pytest.raises(httpx.HTTPStatusError) as exc_info:
await api_client.start_conversation(
api_key='sk-oh-test',
automation_file=b'code',
title='Test',
)
assert exc_info.value.response.status_code == 500
@pytest.mark.asyncio
async def test_start_conversation_auth_error(api_client, respx_mock):
"""start_conversation raises on 401 Unauthorized."""
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(401, json={'error': 'Unauthorized'})
)
with pytest.raises(httpx.HTTPStatusError) as exc_info:
await api_client.start_conversation(
api_key='bad-key',
automation_file=b'code',
title='Test',
)
assert exc_info.value.response.status_code == 401
# ---------------------------------------------------------------------------
# get_conversation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_conversation_returns_data(api_client, respx_mock):
"""get_conversation returns the first conversation from the list."""
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(
200,
json=[
{
'conversation_id': 'conv-123',
'status': 'RUNNING',
'title': 'My Automation',
}
],
)
)
result = await api_client.get_conversation('sk-oh-test', 'conv-123')
assert result is not None
assert result['conversation_id'] == 'conv-123'
assert result['status'] == 'RUNNING'
@pytest.mark.asyncio
async def test_get_conversation_returns_none_when_empty(api_client, respx_mock):
"""get_conversation returns None when API returns empty list."""
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(200, json=[])
)
result = await api_client.get_conversation('sk-oh-test', 'nonexistent')
assert result is None
@pytest.mark.asyncio
async def test_get_conversation_sends_auth_header(api_client, respx_mock):
"""get_conversation sends the correct authorization header."""
route = respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(200, json=[])
)
await api_client.get_conversation('sk-oh-mykey', 'conv-1')
assert route.called
request = route.calls[0].request
assert request.headers['Authorization'] == 'Bearer sk-oh-mykey'
@pytest.mark.asyncio
async def test_get_conversation_http_error(api_client, respx_mock):
"""get_conversation raises on HTTP errors."""
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(503, text='Service Unavailable')
)
with pytest.raises(httpx.HTTPStatusError):
await api_client.get_conversation('sk-oh-test', 'conv-1')
# ---------------------------------------------------------------------------
# close
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_close(api_client):
"""close() shuts down the HTTP client without errors."""
await api_client.close()

View File

@@ -10,6 +10,9 @@ from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4
import pytest
from server.utils.saas_app_conversation_info_injector import (
SaasSQLAppConversationInfoService,
)
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
@@ -17,9 +20,6 @@ from storage.base import Base
from storage.org import Org
from storage.user import User
from enterprise.server.utils.saas_app_conversation_info_injector import (
SaasSQLAppConversationInfoService,
)
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
)

View File

@@ -11,7 +11,6 @@ from server.auth.auth_error import AuthError
from server.auth.saas_user_auth import SaasUserAuth
from server.auth.user.user_authorizer import UserAuthorizationResponse, UserAuthorizer
from server.routes.auth import (
_extract_recaptcha_state,
accept_tos,
authenticate,
keycloak_callback,
@@ -55,11 +54,12 @@ def mock_response():
def test_set_response_cookie(mock_response, mock_request):
"""Test setting the auth cookie on a response."""
with patch('server.routes.auth.config') as mock_config:
with (
patch('server.routes.auth.config') as mock_config,
patch('server.utils.url_utils.get_global_config') as get_global_config,
):
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
# Configure mock_request.url.hostname
mock_request.url.hostname = 'example.com'
get_global_config.return_value = MagicMock(web_url='https://example.com')
set_response_cookie(
request=mock_request,
@@ -1036,79 +1036,6 @@ async def test_keycloak_callback_no_email_in_user_info(
mock_token_manager.check_duplicate_base_email.assert_not_called()
class TestExtractRecaptchaState:
"""Tests for _extract_recaptcha_state() helper function."""
def test_should_extract_redirect_url_and_token_from_new_json_format(self):
"""Test extraction from new base64-encoded JSON format."""
# Arrange
state_data = {
'redirect_url': 'https://example.com',
'recaptcha_token': 'test-token',
}
encoded_state = base64.urlsafe_b64encode(
json.dumps(state_data).encode()
).decode()
# Act
redirect_url, token = _extract_recaptcha_state(encoded_state)
# Assert
assert redirect_url == 'https://example.com'
assert token == 'test-token'
def test_should_handle_old_format_plain_redirect_url(self):
"""Test handling of old format (plain redirect URL string)."""
# Arrange
state = 'https://example.com'
# Act
redirect_url, token = _extract_recaptcha_state(state)
# Assert
assert redirect_url == 'https://example.com'
assert token is None
def test_should_handle_none_state(self):
"""Test handling of None state."""
# Arrange
state = None
# Act
redirect_url, token = _extract_recaptcha_state(state)
# Assert
assert redirect_url == ''
assert token is None
def test_should_handle_invalid_base64_gracefully(self):
"""Test handling of invalid base64/JSON (fallback to old format)."""
# Arrange
state = 'not-valid-base64!!!'
# Act
redirect_url, token = _extract_recaptcha_state(state)
# Assert
assert redirect_url == state
assert token is None
def test_should_handle_missing_redirect_url_in_json(self):
"""Test handling when redirect_url is missing in JSON."""
# Arrange
state_data = {'recaptcha_token': 'test-token'}
encoded_state = base64.urlsafe_b64encode(
json.dumps(state_data).encode()
).decode()
# Act
redirect_url, token = _extract_recaptcha_state(encoded_state)
# Assert
assert redirect_url == ''
assert token == 'test-token'
class TestKeycloakCallbackRecaptcha:
"""Tests for reCAPTCHA integration in keycloak_callback()."""

View File

@@ -48,7 +48,7 @@ def mock_checkout_request():
'server': ('test.com', 80),
}
)
request._base_url = URL('http://test.com/')
request._url = URL('http://test.com/')
return request
@@ -62,7 +62,7 @@ def mock_subscription_request():
'server': ('test.com', 80),
}
)
request._base_url = URL('http://test.com/')
request._url = URL('http://test.com/')
return request
@@ -264,7 +264,7 @@ async def test_create_checkout_session_success(
async def test_success_callback_session_not_found(async_session_maker):
"""Test success callback when billing session is not found."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
mock_request._url = URL('http://test.com/')
with (
patch('server.routes.billing.a_session_maker', async_session_maker),
@@ -281,7 +281,7 @@ async def test_success_callback_stripe_incomplete(
):
"""Test success callback when Stripe session is not complete."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
mock_request._url = URL('http://test.com/')
session_id = 'test_incomplete_session'
async with async_session_maker() as session:
@@ -319,7 +319,7 @@ async def test_success_callback_stripe_incomplete(
async def test_success_callback_success(async_session_maker, test_org, test_user):
"""Test successful payment completion and credit update."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
mock_request._url = URL('http://test.com/')
session_id = 'test_success_session'
async with async_session_maker() as session:
@@ -391,7 +391,7 @@ async def test_success_callback_lite_llm_error(
):
"""Test handling of LiteLLM API errors during success callback."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
mock_request._url = URL('http://test.com/')
session_id = 'test_litellm_error_session'
async with async_session_maker() as session:
@@ -445,7 +445,7 @@ async def test_success_callback_lite_llm_update_budget_error_rollback(
the database transaction rolls back.
"""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
mock_request._url = URL('http://test.com/')
session_id = 'test_budget_rollback_session'
async with async_session_maker() as session:
@@ -502,7 +502,7 @@ async def test_success_callback_lite_llm_update_budget_error_rollback(
async def test_cancel_callback_session_not_found(async_session_maker):
"""Test cancel callback when billing session is not found."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
mock_request._url = URL('http://test.com/')
with patch('server.routes.billing.a_session_maker', async_session_maker):
response = await cancel_callback('nonexistent_session_id', mock_request)
@@ -517,7 +517,7 @@ async def test_cancel_callback_session_not_found(async_session_maker):
async def test_cancel_callback_success(async_session_maker, test_org, test_user):
"""Test successful cancellation of billing session."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
mock_request._url = URL('http://test.com/')
session_id = 'test_cancel_session'
async with async_session_maker() as session:
@@ -588,7 +588,7 @@ async def test_create_customer_setup_session_success():
'headers': [],
}
)
mock_request._base_url = URL('http://test.com/')
mock_request._url = URL('http://test.com/')
mock_customer_info = {'customer_id': 'mock-customer-id', 'org_id': 'mock-org-id'}
mock_session = MagicMock()
@@ -613,6 +613,6 @@ async def test_create_customer_setup_session_success():
customer='mock-customer-id',
mode='setup',
payment_method_types=['card'],
success_url='https://test.com/?setup=success',
cancel_url='https://test.com/',
success_url='https://test.com?setup=success',
cancel_url='https://test.com',
)

View File

@@ -98,6 +98,11 @@ class TestAcceptInvitationEmailValidation:
mock_keycloak_user_info = {'email': 'alice@example.com'} # Email from Keycloak
mock_org = MagicMock()
mock_org.default_llm_model = 'test-model'
mock_org.default_llm_base_url = None
mock_org.default_max_iterations = None
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
@@ -121,6 +126,10 @@ class TestAcceptInvitationEmailValidation:
'server.services.org_invitation_service.OrgService.create_litellm_integration',
new_callable=AsyncMock,
) as mock_create_litellm,
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
new_callable=AsyncMock,
) as mock_get_org,
patch(
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org',
new_callable=AsyncMock,
@@ -145,6 +154,7 @@ class TestAcceptInvitationEmailValidation:
mock_settings = MagicMock()
mock_settings.llm_api_key = SecretStr('test-key')
mock_create_litellm.return_value = mock_settings
mock_get_org.return_value = mock_org
mock_update_status.return_value = mock_invitation
# Act - should not raise error because Keycloak email matches
@@ -214,6 +224,11 @@ class TestAcceptInvitationEmailValidation:
mock_invitation.email = 'alice@example.com' # Lowercase in invitation
mock_org = MagicMock()
mock_org.default_llm_model = 'test-model'
mock_org.default_llm_base_url = None
mock_org.default_max_iterations = None
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
@@ -234,6 +249,10 @@ class TestAcceptInvitationEmailValidation:
'server.services.org_invitation_service.OrgService.create_litellm_integration',
new_callable=AsyncMock,
) as mock_create_litellm,
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
new_callable=AsyncMock,
) as mock_get_org,
patch(
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org',
new_callable=AsyncMock,
@@ -250,6 +269,7 @@ class TestAcceptInvitationEmailValidation:
mock_settings = MagicMock()
mock_settings.llm_api_key = SecretStr('test-key')
mock_create_litellm.return_value = mock_settings
mock_get_org.return_value = mock_org
mock_update_status.return_value = mock_invitation
# Act - should not raise error because emails match case-insensitively
@@ -258,6 +278,75 @@ class TestAcceptInvitationEmailValidation:
# Assert - invitation was accepted (update_invitation_status was called)
mock_update_status.assert_called_once()
@pytest.mark.asyncio
async def test_accept_invitation_inherits_org_llm_settings(self, mock_invitation):
"""Test that new members inherit the organization's LLM settings when accepting invitation."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = 'alice@example.com'
mock_org = MagicMock()
mock_org.default_llm_model = 'claude-sonnet-4'
mock_org.default_llm_base_url = 'https://api.anthropic.com'
mock_org.default_max_iterations = 100
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
new_callable=AsyncMock,
) as mock_get_member,
patch(
'server.services.org_invitation_service.OrgService.create_litellm_integration',
new_callable=AsyncMock,
) as mock_create_litellm,
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
new_callable=AsyncMock,
) as mock_get_org,
patch(
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org',
new_callable=AsyncMock,
) as mock_add_user,
patch(
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
new_callable=AsyncMock,
) as mock_update_status,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
mock_get_member.return_value = None
mock_settings = MagicMock()
mock_settings.llm_api_key = SecretStr('test-key')
mock_create_litellm.return_value = mock_settings
mock_get_org.return_value = mock_org
mock_update_status.return_value = mock_invitation
# Act
await OrgInvitationService.accept_invitation(token, user_id)
# Assert - verify add_user_to_org was called with org's LLM settings
mock_add_user.assert_called_once()
call_kwargs = mock_add_user.call_args.kwargs
assert call_kwargs['llm_model'] == 'claude-sonnet-4'
assert call_kwargs['llm_base_url'] == 'https://api.anthropic.com'
assert call_kwargs['max_iterations'] == 100
class TestCreateInvitationsBatch:
"""Test cases for batch invitation creation."""

View File

@@ -246,6 +246,43 @@ async def test_add_user_to_org(async_session_maker):
assert org_member.status == 'active'
@pytest.mark.asyncio
async def test_add_user_to_org_with_llm_settings(async_session_maker):
"""Test that add_user_to_org correctly sets inherited LLM settings from organization."""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org-llm')
session.add(org)
await session.flush()
user = User(id=uuid.uuid4(), current_org_id=org.id)
role = Role(name='member', rank=2)
session.add_all([user, role])
await session.commit()
org_id = org.id
user_id = user.id
role_id = role.id
# Act
with patch('storage.org_member_store.a_session_maker', async_session_maker):
org_member = await OrgMemberStore.add_user_to_org(
org_id=org_id,
user_id=user_id,
role_id=role_id,
llm_api_key='test-api-key',
status='active',
llm_model='claude-sonnet-4',
llm_base_url='https://api.example.com',
max_iterations=50,
)
# Assert
assert org_member is not None
assert org_member.llm_model == 'claude-sonnet-4'
assert org_member.llm_base_url == 'https://api.example.com'
assert org_member.max_iterations == 50
@pytest.mark.asyncio
async def test_update_user_role_in_org(async_session_maker):
# Test updating user role in org

View File

@@ -396,3 +396,44 @@ async def test_store_propagates_llm_settings_to_all_org_members(
assert (
decrypted_key == 'new-shared-api-key'
), f'Expected llm_api_key to decrypt to new-shared-api-key for member {member.user_id}'
@pytest.mark.asyncio
async def test_store_updates_org_default_llm_settings(
session_maker, async_session_maker, mock_config, org_with_multiple_members_fixture
):
"""When admin saves LLM settings, org's default_llm_model/base_url/max_iterations should be updated.
This test verifies that the Org table's default settings are updated so that
new members joining later will inherit the correct LLM configuration.
"""
from sqlalchemy import select
from storage.org import Org
# Arrange
fixture = org_with_multiple_members_fixture
org_id = fixture['org_id']
admin_user_id = str(fixture['admin_user_id'])
store = SaasSettingsStore(admin_user_id, mock_config)
new_settings = DataSettings(
llm_model='anthropic/claude-sonnet-4',
llm_base_url='https://api.anthropic.com/v1',
max_iterations=75,
llm_api_key=SecretStr('test-api-key'),
)
# Act
with patch('storage.saas_settings_store.a_session_maker', async_session_maker):
await store.store(new_settings)
# Assert - verify org's default fields were updated
with session_maker() as session:
result = session.execute(select(Org).where(Org.id == org_id))
org = result.scalars().first()
assert org is not None
assert org.default_llm_model == 'anthropic/claude-sonnet-4'
assert org.default_llm_base_url == 'https://api.anthropic.com/v1'
assert org.default_max_iterations == 75

View File

@@ -0,0 +1 @@
# Tests for enterprise server utils

View File

@@ -0,0 +1,425 @@
"""Tests for URL utility functions that prevent URL hijacking attacks."""
from unittest.mock import MagicMock, patch
import pytest
class TestGetWebUrl:
"""Tests for get_web_url function."""
@pytest.fixture
def mock_request(self):
"""Create a mock FastAPI request object."""
request = MagicMock()
request.url = MagicMock()
return request
def test_configured_web_url_is_used(self, mock_request):
"""When web_url is configured, it should be used instead of request URL."""
from server.utils.url_utils import get_web_url
mock_request.url.hostname = 'evil-attacker.com'
mock_request.url.netloc = 'evil-attacker.com:443'
mock_config = MagicMock()
mock_config.web_url = 'https://app.all-hands.dev'
with patch(
'server.utils.url_utils.get_global_config', return_value=mock_config
):
result = get_web_url(mock_request)
assert result == 'https://app.all-hands.dev'
# Should not use any info from the potentially poisoned request
assert 'evil-attacker.com' not in result
def test_configured_web_url_trailing_slash_stripped(self, mock_request):
"""Configured web_url should have trailing slashes stripped."""
from server.utils.url_utils import get_web_url
mock_config = MagicMock()
mock_config.web_url = 'https://app.all-hands.dev/'
with patch(
'server.utils.url_utils.get_global_config', return_value=mock_config
):
result = get_web_url(mock_request)
assert result == 'https://app.all-hands.dev'
assert not result.endswith('/')
def test_unconfigured_web_url_localhost_uses_http(self, mock_request):
"""When web_url is not configured and hostname is localhost, use http."""
from server.utils.url_utils import get_web_url
mock_request.url.hostname = 'localhost'
mock_request.url.netloc = 'localhost:3000'
mock_config = MagicMock()
mock_config.web_url = None
with patch(
'server.utils.url_utils.get_global_config', return_value=mock_config
):
result = get_web_url(mock_request)
assert result == 'http://localhost:3000'
def test_unconfigured_web_url_non_localhost_uses_https(self, mock_request):
"""When web_url is not configured and hostname is not localhost, use https."""
from server.utils.url_utils import get_web_url
mock_request.url.hostname = 'example.com'
mock_request.url.netloc = 'example.com:443'
mock_config = MagicMock()
mock_config.web_url = None
with patch(
'server.utils.url_utils.get_global_config', return_value=mock_config
):
result = get_web_url(mock_request)
assert result == 'https://example.com:443'
def test_unconfigured_web_url_empty_string_fallback(self, mock_request):
"""Empty string web_url should trigger fallback."""
from server.utils.url_utils import get_web_url
mock_request.url.hostname = 'localhost'
mock_request.url.netloc = 'localhost:3000'
mock_config = MagicMock()
mock_config.web_url = ''
with patch(
'server.utils.url_utils.get_global_config', return_value=mock_config
):
result = get_web_url(mock_request)
assert result == 'http://localhost:3000'
class TestGetCookieDomain:
"""Tests for get_cookie_domain function."""
def test_production_with_configured_web_url(self):
"""In production with web_url configured, should return hostname."""
from server.utils.url_utils import get_cookie_domain
mock_config = MagicMock()
mock_config.web_url = 'https://app.all-hands.dev'
with (
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
patch('server.utils.url_utils.IS_STAGING_ENV', False),
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
):
result = get_cookie_domain()
assert result == 'app.all-hands.dev'
def test_production_without_web_url_returns_none(self):
"""In production without web_url configured, should return None."""
from server.utils.url_utils import get_cookie_domain
mock_config = MagicMock()
mock_config.web_url = None
with (
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
patch('server.utils.url_utils.IS_STAGING_ENV', False),
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
):
result = get_cookie_domain()
assert result is None
def test_local_env_returns_none(self):
"""In local environment, should return None for cookie domain."""
from server.utils.url_utils import get_cookie_domain
mock_config = MagicMock()
mock_config.web_url = 'https://app.all-hands.dev'
with (
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
patch('server.utils.url_utils.IS_STAGING_ENV', False),
patch('server.utils.url_utils.IS_LOCAL_ENV', True),
):
result = get_cookie_domain()
assert result is None
def test_staging_env_returns_none(self):
"""In staging environment, should return None for cookie domain."""
from server.utils.url_utils import get_cookie_domain
mock_config = MagicMock()
mock_config.web_url = 'https://staging.all-hands.dev'
with (
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
patch('server.utils.url_utils.IS_STAGING_ENV', True),
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
):
result = get_cookie_domain()
assert result is None
def test_feature_env_returns_none(self):
"""In feature environment, should return None for cookie domain."""
from server.utils.url_utils import get_cookie_domain
mock_config = MagicMock()
mock_config.web_url = 'https://feature-123.staging.all-hands.dev'
with (
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
patch('server.utils.url_utils.IS_FEATURE_ENV', True),
patch('server.utils.url_utils.IS_STAGING_ENV', True),
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
):
result = get_cookie_domain()
assert result is None
class TestGetCookieSamesite:
"""Tests for get_cookie_samesite function."""
def test_production_with_configured_web_url_returns_strict(self):
"""In production with web_url configured, should return 'strict'."""
from server.utils.url_utils import get_cookie_samesite
mock_config = MagicMock()
mock_config.web_url = 'https://app.all-hands.dev'
with (
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
patch('server.utils.url_utils.IS_STAGING_ENV', False),
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
):
result = get_cookie_samesite()
assert result == 'strict'
def test_production_without_web_url_returns_lax(self):
"""In production without web_url configured, should return 'lax'."""
from server.utils.url_utils import get_cookie_samesite
mock_config = MagicMock()
mock_config.web_url = None
with (
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
patch('server.utils.url_utils.IS_STAGING_ENV', False),
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
):
result = get_cookie_samesite()
assert result == 'lax'
def test_local_env_returns_lax(self):
"""In local environment, should return 'lax'."""
from server.utils.url_utils import get_cookie_samesite
mock_config = MagicMock()
mock_config.web_url = 'http://localhost:3000'
with (
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
patch('server.utils.url_utils.IS_STAGING_ENV', False),
patch('server.utils.url_utils.IS_LOCAL_ENV', True),
):
result = get_cookie_samesite()
assert result == 'lax'
def test_staging_env_returns_lax(self):
"""In staging environment, should return 'lax'."""
from server.utils.url_utils import get_cookie_samesite
mock_config = MagicMock()
mock_config.web_url = 'https://staging.all-hands.dev'
with (
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
patch('server.utils.url_utils.IS_STAGING_ENV', True),
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
):
result = get_cookie_samesite()
assert result == 'lax'
def test_feature_env_returns_lax(self):
"""In feature environment, should return 'lax'."""
from server.utils.url_utils import get_cookie_samesite
mock_config = MagicMock()
mock_config.web_url = 'https://feature-xyz.staging.all-hands.dev'
with (
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
patch('server.utils.url_utils.IS_FEATURE_ENV', True),
patch('server.utils.url_utils.IS_STAGING_ENV', True),
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
):
result = get_cookie_samesite()
assert result == 'lax'
def test_empty_web_url_returns_lax(self):
"""Empty web_url should be treated as unconfigured and return 'lax'."""
from server.utils.url_utils import get_cookie_samesite
mock_config = MagicMock()
mock_config.web_url = ''
with (
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
patch('server.utils.url_utils.IS_STAGING_ENV', False),
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
):
result = get_cookie_samesite()
assert result == 'lax'
class TestSecurityScenarios:
"""Tests for security-critical scenarios."""
@pytest.fixture
def mock_request(self):
"""Create a mock FastAPI request object."""
request = MagicMock()
request.url = MagicMock()
return request
def test_header_poisoning_attack_blocked_when_configured(self, mock_request):
"""
When web_url is configured, X-Forwarded-* header poisoning should not affect
the returned URL.
"""
from server.utils.url_utils import get_web_url
# Simulate a poisoned request where attacker controls headers
mock_request.url.hostname = 'evil.com'
mock_request.url.netloc = 'evil.com:443'
mock_config = MagicMock()
mock_config.web_url = 'https://app.all-hands.dev'
with patch(
'server.utils.url_utils.get_global_config', return_value=mock_config
):
result = get_web_url(mock_request)
# Should use configured web_url, not the poisoned request data
assert result == 'https://app.all-hands.dev'
assert 'evil' not in result
def test_cookie_domain_not_set_in_dev_environments(self):
"""
Cookie domain should not be set in development environments to prevent
cookies from leaking to other subdomains.
"""
from server.utils.url_utils import get_cookie_domain
mock_config = MagicMock()
mock_config.web_url = 'https://my-feature.staging.all-hands.dev'
# Test each dev environment
for env_name, env_config in [
(
'local',
{
'IS_LOCAL_ENV': True,
'IS_STAGING_ENV': False,
'IS_FEATURE_ENV': False,
},
),
(
'staging',
{
'IS_LOCAL_ENV': False,
'IS_STAGING_ENV': True,
'IS_FEATURE_ENV': False,
},
),
(
'feature',
{'IS_LOCAL_ENV': False, 'IS_STAGING_ENV': True, 'IS_FEATURE_ENV': True},
),
]:
with (
patch(
'server.utils.url_utils.get_global_config', return_value=mock_config
),
patch(
'server.utils.url_utils.IS_FEATURE_ENV',
env_config['IS_FEATURE_ENV'],
),
patch(
'server.utils.url_utils.IS_STAGING_ENV',
env_config['IS_STAGING_ENV'],
),
patch(
'server.utils.url_utils.IS_LOCAL_ENV', env_config['IS_LOCAL_ENV']
),
):
result = get_cookie_domain()
assert result is None, f'Expected None for {env_name} environment'
def test_strict_samesite_only_in_production(self):
"""
SameSite=strict should only be set in production to ensure proper
security without breaking OAuth flows in development.
"""
from server.utils.url_utils import get_cookie_samesite
mock_config = MagicMock()
mock_config.web_url = 'https://app.all-hands.dev'
# Production should be strict
with (
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
patch('server.utils.url_utils.IS_STAGING_ENV', False),
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
):
assert get_cookie_samesite() == 'strict'
# Dev environments should be lax
for env_config in [
{'IS_LOCAL_ENV': True, 'IS_STAGING_ENV': False, 'IS_FEATURE_ENV': False},
{'IS_LOCAL_ENV': False, 'IS_STAGING_ENV': True, 'IS_FEATURE_ENV': False},
{'IS_LOCAL_ENV': False, 'IS_STAGING_ENV': True, 'IS_FEATURE_ENV': True},
]:
with (
patch(
'server.utils.url_utils.get_global_config', return_value=mock_config
),
patch(
'server.utils.url_utils.IS_FEATURE_ENV',
env_config['IS_FEATURE_ENV'],
),
patch(
'server.utils.url_utils.IS_STAGING_ENV',
env_config['IS_STAGING_ENV'],
),
patch(
'server.utils.url_utils.IS_LOCAL_ENV', env_config['IS_LOCAL_ENV']
),
):
assert get_cookie_samesite() == 'lax'

View File

@@ -1,12 +1,12 @@
{
"name": "openhands-frontend",
"version": "1.4.0",
"version": "1.5.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "openhands-frontend",
"version": "1.4.0",
"version": "1.5.0",
"dependencies": {
"@heroui/react": "2.8.7",
"@microlink/react-json-view": "^1.27.1",

View File

@@ -1,6 +1,6 @@
{
"name": "openhands-frontend",
"version": "1.4.0",
"version": "1.5.0",
"private": true,
"type": "module",
"engines": {

View File

@@ -144,7 +144,7 @@ runtime = [
[tool.poetry]
name = "openhands-ai"
version = "1.4.0"
version = "1.5.0"
description = "OpenHands: Code Less, Make More"
authors = [ "OpenHands" ]
license = "MIT"