mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
3 Commits
pr13306
...
auto/execu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff5274f60b | ||
|
|
d57e6c134c | ||
|
|
d3cc121b08 |
71
enterprise/run_automation_executor.py
Normal file
71
enterprise/run_automation_executor.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Entry point for the automation executor.
|
||||
|
||||
Usage: python -m run_automation_executor
|
||||
|
||||
This runs as a Kubernetes Deployment (long-running). It polls the automation_events
|
||||
inbox, matches events to automations, claims and executes runs, and monitors
|
||||
conversation completion.
|
||||
|
||||
Environment variables:
|
||||
OPENHANDS_API_URL Base URL for the V1 API (default: http://openhands-service:3000)
|
||||
MAX_CONCURRENT_RUNS Max concurrent runs per executor (default: 5)
|
||||
RUN_TIMEOUT_SECONDS Max time for a single run (default: 7200)
|
||||
POLL_INTERVAL_SECONDS Fallback poll interval (default: 30)
|
||||
HEARTBEAT_INTERVAL_SECONDS Heartbeat update interval (default: 60)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger('saas.automation.executor')
|
||||
|
||||
|
||||
def _setup_logging() -> None:
    """Initialise logging for the executor process.

    Uses the enterprise ``setup_all_loggers`` helper when ``server.logger``
    can be imported; on ``ImportError`` it falls back to a basic stdout
    configuration at INFO level.
    """
    fallback_config = {
        'level': logging.INFO,
        'format': '%(asctime)s %(name)s %(levelname)s %(message)s',
        'stream': sys.stdout,
    }
    try:
        from server.logger import setup_all_loggers

        setup_all_loggers()
    except ImportError:
        # Enterprise logger not importable — use the plain stdlib setup.
        logging.basicConfig(**fallback_config)
|
||||
|
||||
|
||||
def _install_signal_handlers(loop: asyncio.AbstractEventLoop) -> None:
    """Install SIGTERM/SIGINT handlers that request a graceful shutdown.

    Handlers are registered with ``loop.add_signal_handler`` so the callback
    runs inside the event loop and wakes it up right away.  A handler
    installed with plain ``signal.signal`` is not allowed to interact with
    the loop directly, which can leave the executor sleeping until its next
    scheduled wake-up before it notices the shutdown request.

    On event loops that do not implement ``add_signal_handler`` the function
    falls back to ``signal.signal`` and hands the request to the loop via
    ``call_soon_threadsafe``, which also wakes the loop.

    Args:
        loop: The event loop that will run the executor.
    """
    from services.automation_executor import request_shutdown

    def _request_shutdown(sig_name: str) -> None:
        logger.info('Received %s, initiating graceful shutdown...', sig_name)
        request_shutdown()

    def _handle_signal(signum: int, _frame: object) -> None:
        # Fallback path: defer the work to the loop instead of touching
        # asyncio primitives from inside the raw signal handler.
        loop.call_soon_threadsafe(_request_shutdown, signal.Signals(signum).name)

    for sig in (signal.SIGTERM, signal.SIGINT):
        try:
            loop.add_signal_handler(sig, _request_shutdown, sig.name)
        except NotImplementedError:
            signal.signal(sig, _handle_signal)
|
||||
|
||||
|
||||
async def main() -> None:
    """Run the automation executor loop until it returns."""
    # Imported lazily so importing this module does not pull in the service.
    from services.automation_executor import executor_main as run_executor

    await run_executor()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: configure logging, then drive main() on a
    # dedicated event loop that is always closed on the way out.
    _setup_logging()

    event_loop = asyncio.new_event_loop()
    _install_signal_handlers(event_loop)

    logger.info('Starting automation executor')
    try:
        event_loop.run_until_complete(main())
    except KeyboardInterrupt:
        logger.info('Interrupted by user')
    finally:
        event_loop.close()
        logger.info('Automation executor process exiting')
|
||||
0
enterprise/services/__init__.py
Normal file
0
enterprise/services/__init__.py
Normal file
555
enterprise/services/automation_executor.py
Normal file
555
enterprise/services/automation_executor.py
Normal file
@@ -0,0 +1,555 @@
|
||||
"""Automation executor — processes events, claims and executes runs.
|
||||
|
||||
The executor is a long-running process with three phases:
|
||||
1. Process inbox: match NEW events to automations, create PENDING runs
|
||||
2. Claim and execute: claim PENDING runs, submit to V1 API, heartbeat
|
||||
3. Stale recovery: recover RUNNING runs with expired heartbeats
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from services.openhands_api_client import OpenHandsAPIClient
|
||||
from sqlalchemy import or_, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.automation import Automation, AutomationRun
|
||||
from storage.automation_event import AutomationEvent
|
||||
|
||||
logger = logging.getLogger('saas.automation.executor')

# Settings read from the environment once, at import time.
# Seconds the main loop waits between polling passes.
POLL_INTERVAL_SECONDS = float(os.environ.get('POLL_INTERVAL_SECONDS', '30'))
# Seconds between heartbeat updates while a run is being monitored.
HEARTBEAT_INTERVAL_SECONDS = float(os.environ.get('HEARTBEAT_INTERVAL_SECONDS', '60'))
# Maximum wall-clock seconds a single run may take.
RUN_TIMEOUT_SECONDS = float(os.environ.get('RUN_TIMEOUT_SECONDS', '7200'))
# Maximum number of runs one executor should have in flight.
MAX_CONCURRENT_RUNS = int(os.environ.get('MAX_CONCURRENT_RUNS', '5'))

# Fixed (non-configurable) limits.
STALE_THRESHOLD_MINUTES = 5
MAX_EVENTS_PER_BATCH = 50
MAX_RETRIES_DEFAULT = 3

# Conversation statuses after which monitoring stops.
TERMINAL_STATUSES = frozenset({'STOPPED', 'ERROR', 'COMPLETED', 'CANCELLED'})
|
||||
|
||||
# Shutdown flag — set by signal handlers
|
||||
_shutdown_event: asyncio.Event | None = None
|
||||
|
||||
# Background task tracking for graceful shutdown
|
||||
_pending_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def get_shutdown_event() -> asyncio.Event:
|
||||
global _shutdown_event
|
||||
if _shutdown_event is None:
|
||||
_shutdown_event = asyncio.Event()
|
||||
return _shutdown_event
|
||||
|
||||
|
||||
def should_continue() -> bool:
|
||||
return not get_shutdown_event().is_set()
|
||||
|
||||
|
||||
def request_shutdown() -> None:
|
||||
get_shutdown_event().set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 1: Process inbox (event matching)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def find_matching_automations(
    session: AsyncSession, event: AutomationEvent
) -> list[Automation]:
    """Return the enabled automations targeted by ``event``.

    Only ``cron`` and ``manual`` events are handled; both are expected to
    carry ``automation_id`` in the payload.  Any other source type, a ``None``
    payload, or a missing ``automation_id`` yields an empty list.
    """
    event_payload = event.payload
    kind = event.source_type

    if event_payload is None:
        logger.error('Event %s has None payload — possible data corruption', event.id)
        return []

    if kind not in ('cron', 'manual'):
        logger.debug('Unhandled event source_type=%s for event %s', kind, event.id)
        return []

    target_id = event_payload.get('automation_id')
    if not target_id:
        logger.warning(
            'Event %s (source=%s) missing automation_id in payload',
            event.id,
            kind,
        )
        return []

    # Disabled automations are filtered out by the query itself.
    query = select(Automation).where(
        Automation.id == target_id,
        Automation.enabled.is_(True),
    )
    match = (await session.execute(query)).scalar_one_or_none()
    if match:
        return [match]
    return []
|
||||
|
||||
|
||||
async def process_new_events(session: AsyncSession) -> int:
    """Claim NEW events from the inbox, match them to automations, create runs.

    Up to ``MAX_EVENTS_PER_BATCH`` events are selected oldest-first with
    ``FOR UPDATE SKIP LOCKED``.  Each event ends up as ``PROCESSED`` (one
    PENDING run per matching automation), ``NO_MATCH``, or ``ERROR`` (with
    ``error_detail`` populated).

    Args:
        session: Database session used for the select and the final commit.

    Returns:
        The number of events handled without raising (PROCESSED + NO_MATCH).
        Events that ended in ERROR are persisted but not counted.
    """
    result = await session.execute(
        select(AutomationEvent)
        .where(AutomationEvent.status == 'NEW')
        .order_by(AutomationEvent.created_at)
        .limit(MAX_EVENTS_PER_BATCH)
        .with_for_update(skip_locked=True)
    )
    events = list(result.scalars())

    processed = 0
    for event in events:
        try:
            automations = await find_matching_automations(session, event)
            if not automations:
                event.status = 'NO_MATCH'
            else:
                for automation in automations:
                    session.add(
                        AutomationRun(
                            id=uuid4().hex,
                            automation_id=automation.id,
                            event_id=event.id,
                            status='PENDING',
                            event_payload=event.payload,
                        )
                    )
                event.status = 'PROCESSED'
            event.processed_at = utc_now()
            processed += 1
        except Exception as e:
            logger.exception('Error processing event %s', event.id)
            event.status = 'ERROR'
            event.error_detail = f'Failed during event matching: {type(e).__name__}: {e}'
            event.processed_at = utc_now()

    # Commit whenever anything was touched.  Gating this on ``processed``
    # would drop the ERROR updates of a batch in which every event failed,
    # leaving those events NEW and re-processing them on every poll.
    if events:
        await session.commit()
    if processed:
        logger.info('Processed %d events', processed)

    return processed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 2: Claim and execute runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def resolve_user_api_key(session: AsyncSession, user_id: str) -> str | None:
    """Fetch one API key belonging to ``user_id`` from the api_keys table.

    Returns the key of the first row the query yields, or None when the user
    has no key.
    """
    # NOTE(review): the query filters on user_id only; it applies no
    # "active"/expiry condition — confirm whether one is needed.
    from storage.api_key import ApiKey

    key_query = select(ApiKey.key).where(ApiKey.user_id == user_id).limit(1)
    rows = await session.execute(key_query)
    return rows.scalar_one_or_none()
|
||||
|
||||
|
||||
async def download_automation_file(file_store_key: str) -> bytes:
    """Download the automation .py file from object storage.

    The read goes through the synchronous ``file_store.read``; it is executed
    in a worker thread so a slow storage backend does not block the event
    loop (and with it the heartbeats of other in-flight runs).

    Args:
        file_store_key: Storage key of the automation script.

    Returns:
        The file content as bytes; text content is encoded as UTF-8.

    Raises:
        RuntimeError: If ``openhands.server.shared.file_store`` cannot be
            imported.
    """
    try:
        from openhands.server.shared import file_store
    except ImportError as exc:
        raise RuntimeError(
            'file_store is not available — ensure the enterprise server '
            'has been initialised before calling download_automation_file'
        ) from exc

    content = await asyncio.to_thread(file_store.read, file_store_key)
    if isinstance(content, str):
        return content.encode('utf-8')
    return content
|
||||
|
||||
|
||||
def is_terminal(conversation: dict) -> bool:
    """Return True when the conversation's status is a terminal one.

    The ``status`` value is compared case-insensitively against
    ``TERMINAL_STATUSES``; a missing or empty status is not terminal.
    """
    raw_status = conversation.get('status')
    normalized = raw_status.upper() if raw_status else ''
    return normalized in TERMINAL_STATUSES
|
||||
|
||||
|
||||
async def _prepare_run(
    run: AutomationRun,
    automation: Automation,
    session_factory: object,
) -> tuple[str, bytes]:
    """Gather what is needed to submit a run: the API key and the script.

    Args:
        run: The run being prepared (not read by this function).
        automation: Supplies ``user_id`` and ``file_store_key``.
        session_factory: Callable returning an async session context manager.

    Returns:
        (api_key, automation_file) tuple ready for submission.

    Raises:
        ValueError: If no API key is found.
        RuntimeError: If file_store is unavailable.
    """
    owner_id = automation.user_id
    async with session_factory() as db:
        api_key = await resolve_user_api_key(db, owner_id)

    if not api_key:
        raise ValueError(f'No API key found for user {automation.user_id}')

    script_bytes = await download_automation_file(automation.file_store_key)
    return api_key, script_bytes
|
||||
|
||||
|
||||
async def _monitor_conversation(
    run: AutomationRun,
    conversation_id: str,
    api_client: OpenHandsAPIClient,
    api_key: str,
    session_factory: object,
) -> bool:
    """Monitor a conversation until it reaches a terminal status.

    Each cycle waits up to ``HEARTBEAT_INTERVAL_SECONDS``, refreshes the run's
    ``heartbeat_at`` and then polls the conversation status.  The wait is on
    the shutdown event rather than a plain sleep, so a shutdown request is
    noticed immediately instead of after a full heartbeat interval.

    Args:
        run: The run being monitored; only ``run.id`` is read.
        conversation_id: Conversation to poll.
        api_client: Client used to fetch the conversation.
        api_key: API key passed to the client.
        session_factory: Callable returning an async session context manager.

    Returns:
        True if the conversation reached a terminal status, False if shutdown
        was requested first.

    Raises:
        TimeoutError: If the run exceeds RUN_TIMEOUT_SECONDS.
    """
    # NOTE(review): every terminal status — including 'ERROR' and
    # 'CANCELLED' — returns True here; confirm callers should treat those as
    # a completed run.
    start_time = utc_now()
    shutdown_event = get_shutdown_event()
    while should_continue():
        elapsed = (utc_now() - start_time).total_seconds()
        if elapsed > RUN_TIMEOUT_SECONDS:
            raise TimeoutError(f'Run exceeded {RUN_TIMEOUT_SECONDS}s timeout')

        # Interruptible sleep: returns early as soon as shutdown is requested.
        try:
            await asyncio.wait_for(
                shutdown_event.wait(), timeout=HEARTBEAT_INTERVAL_SECONDS
            )
        except asyncio.TimeoutError:
            pass  # Normal — heartbeat interval elapsed

        # Update heartbeat
        async with session_factory() as session:
            run_obj = await session.get(AutomationRun, run.id)
            if run_obj:
                run_obj.heartbeat_at = utc_now()
                await session.commit()

        # Check conversation status (also done once more after a shutdown
        # request, so a run that just finished is still recorded).
        conversation = (
            await api_client.get_conversation(api_key, conversation_id) or {}
        )
        if is_terminal(conversation):
            return True

    return False  # shutdown requested
|
||||
|
||||
|
||||
async def _submit_and_monitor(
    run: AutomationRun,
    api_key: str,
    automation_file: bytes,
    automation: Automation,
    api_client: OpenHandsAPIClient,
    session_factory: object,
) -> None:
    """Submit the automation to the V1 API and monitor until completion.

    Starts a conversation, stores its id on the run, monitors it with
    heartbeats, and marks the run COMPLETED once the conversation reaches a
    terminal status.  If the executor is shutting down the run is left
    RUNNING.

    Args:
        run: The claimed run; ``id`` and ``event_payload`` are read.
        api_key: API key used for the API calls.
        automation_file: Raw bytes of the automation script.
        automation: Supplies the name used in the conversation title.
        api_client: Client for the V1 API.
        session_factory: Callable returning an async session context manager.

    Raises:
        RuntimeError: If the start-conversation response carries no
            conversation id.
        TimeoutError: Propagated from monitoring when the run times out.
    """
    conversation = await api_client.start_conversation(
        api_key=api_key,
        automation_file=automation_file,
        title=f'Automation: {automation.name}',
        event_payload=run.event_payload,
    )

    conversation_id = conversation.get('app_conversation_id') or conversation.get(
        'conversation_id'
    )
    if not conversation_id:
        # Without an id there is nothing to poll; fail now instead of
        # monitoring a None id until the run timeout expires.
        raise RuntimeError(
            'start_conversation response did not include a conversation id'
        )

    # Persist conversation ID
    async with session_factory() as update_session:
        run_obj = await update_session.get(AutomationRun, run.id)
        if run_obj:
            run_obj.conversation_id = conversation_id
            await update_session.commit()

    # Monitor with heartbeats
    completed = await _monitor_conversation(
        run, conversation_id, api_client, api_key, session_factory
    )

    # Update final status
    async with session_factory() as final_session:
        run_obj = await final_session.get(AutomationRun, run.id)
        if run_obj:
            if not completed:
                # Leave as RUNNING — stale recovery will handle it if needed.
                # The conversation may still be running on the API side.
                logger.info(
                    'Run %s left as RUNNING due to executor shutdown', run.id
                )
            else:
                run_obj.status = 'COMPLETED'
                run_obj.completed_at = utc_now()
                logger.info('Run %s completed successfully', run.id)
                await final_session.commit()
|
||||
|
||||
|
||||
async def execute_run(
    run: AutomationRun,
    automation: Automation,
    api_client: OpenHandsAPIClient,
    session_factory: object,
) -> None:
    """Drive one automation run from preparation through monitoring.

    Prepares the run (API key lookup + script download), then submits and
    monitors it.  Any exception is logged and handed to ``_mark_run_failed``,
    which decides between retry and dead-letter; nothing is re-raised.
    """
    try:
        prepared = await _prepare_run(run, automation, session_factory)
        api_key, automation_file = prepared
        await _submit_and_monitor(
            run,
            api_key,
            automation_file,
            automation,
            api_client,
            session_factory,
        )
    except Exception as e:
        logger.exception('Run %s failed: %s', run.id, e)
        failure_reason = str(e)
        await _mark_run_failed(run, failure_reason, session_factory)
|
||||
|
||||
|
||||
async def _mark_run_failed(
    run: AutomationRun, error: str, session_factory: object
) -> None:
    """Record a failure: move the run to DEAD_LETTER or back to PENDING.

    The run's ``retry_count`` is incremented and ``error_detail`` is set.
    When the new count reaches the run's retry limit the run becomes
    DEAD_LETTER; otherwise it returns to PENDING with an exponential backoff
    (30s, 60s, 120s, ...) stored in ``next_retry_at``.

    Args:
        run: The failed run; only ``run.id`` is read.
        error: Text stored in ``error_detail``.
        session_factory: Callable returning an async session context manager.
    """
    async with session_factory() as session:
        run_obj = await session.get(AutomationRun, run.id)
        if not run_obj:
            return

        run_obj.retry_count = (run_obj.retry_count or 0) + 1
        run_obj.error_detail = error

        # Fall back to the default only when no limit is stored; an explicit
        # limit of 0 means "never retry" and must not be replaced.
        max_retries = (
            run_obj.max_retries
            if run_obj.max_retries is not None
            else MAX_RETRIES_DEFAULT
        )

        if run_obj.retry_count >= max_retries:
            run_obj.status = 'DEAD_LETTER'
            run_obj.completed_at = utc_now()
            logger.error(
                'Run %s moved to DEAD_LETTER after %d retries',
                run.id,
                run_obj.retry_count,
            )
        else:
            run_obj.status = 'PENDING'
            run_obj.claimed_by = None
            backoff_seconds = 30 * (2 ** (run_obj.retry_count - 1))
            run_obj.next_retry_at = utc_now() + timedelta(seconds=backoff_seconds)
            logger.warning(
                'Run %s returned to PENDING, retry %d/%d in %ds',
                run.id,
                run_obj.retry_count,
                max_retries,
                backoff_seconds,
            )

        await session.commit()
|
||||
|
||||
|
||||
async def claim_and_execute_runs(
    session: AsyncSession,
    executor_id: str,
    api_client: OpenHandsAPIClient,
    session_factory: object,
) -> bool:
    """Claim one PENDING run and start executing it in the background.

    The oldest PENDING run whose ``next_retry_at`` is unset or due is selected
    with ``FOR UPDATE SKIP LOCKED``, marked RUNNING for this executor, and
    handed to ``execute_run`` as a tracked asyncio task.  Nothing is claimed
    while this executor already has ``MAX_CONCURRENT_RUNS`` tasks in flight.

    Args:
        session: Session used to select and claim the run.
        executor_id: Identifier stored in ``claimed_by``.
        api_client: Client passed through to ``execute_run``.
        session_factory: Callable returning an async session context manager.

    Returns:
        True if a run was claimed, False otherwise (no eligible run, or the
        executor is at capacity).
    """
    # Enforce the per-executor concurrency cap before touching the database.
    if len(_pending_tasks) >= MAX_CONCURRENT_RUNS:
        logger.debug(
            'At capacity (%d/%d runs in flight); not claiming new runs',
            len(_pending_tasks),
            MAX_CONCURRENT_RUNS,
        )
        return False

    result = await session.execute(
        select(AutomationRun)
        .where(
            AutomationRun.status == 'PENDING',
            or_(
                AutomationRun.next_retry_at.is_(None),
                AutomationRun.next_retry_at <= utc_now(),
            ),
        )
        .order_by(AutomationRun.created_at)
        .limit(1)
        .with_for_update(skip_locked=True)
    )
    run = result.scalar_one_or_none()
    if not run:
        return False

    # Claim the run
    run.status = 'RUNNING'
    run.claimed_by = executor_id
    run.claimed_at = utc_now()
    run.heartbeat_at = utc_now()
    run.started_at = utc_now()
    await session.commit()

    # NOTE(review): ``run`` attributes are read after the commit and the
    # object is handed to a background task — presumably the session factory
    # disables expire-on-commit; confirm.

    # Load automation for the run
    auto_result = await session.execute(
        select(Automation).where(Automation.id == run.automation_id)
    )
    automation = auto_result.scalar_one_or_none()
    if not automation:
        logger.error('Automation %s not found for run %s', run.automation_id, run.id)
        await _mark_run_failed(
            run, f'Automation {run.automation_id} not found', session_factory
        )
        return True

    # Execute in background (long-running) with task tracking
    task = asyncio.create_task(
        execute_run(run, automation, api_client, session_factory),
        name=f'execute-run-{run.id}',
    )
    _pending_tasks.add(task)
    # Drop the task from the tracking set as soon as it finishes.
    task.add_done_callback(_pending_tasks.discard)

    logger.info(
        'Claimed run %s (automation=%s) by executor %s',
        run.id,
        run.automation_id,
        executor_id,
    )
    return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 3: Stale run recovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def recover_stale_runs(session: AsyncSession) -> int:
    """Recover or dead-letter RUNNING runs whose heartbeat has expired.

    A run is stale when its ``heartbeat_at`` is older than
    ``STALE_THRESHOLD_MINUTES``.  Three updates are applied, all measured
    against a single timestamp taken at the start of the pass:

    1. Stale runs whose next retry would reach their retry limit become
       DEAD_LETTER, so a run that repeatedly goes stale cannot be retried
       forever.
    2. Remaining stale runs return to PENDING with ``retry_count`` incremented
       and a 30 second ``next_retry_at`` delay.
    3. Runs whose heartbeat is older than ``RUN_TIMEOUT_SECONDS`` become
       DEAD_LETTER.

    Args:
        session: Session used for the updates and the commit.

    Returns:
        The total number of runs changed by this pass.
    """
    # ``func`` is not imported at module level; import it locally.
    from sqlalchemy import func

    now = utc_now()
    stale_threshold = now - timedelta(minutes=STALE_THRESHOLD_MINUTES)
    timeout_threshold = now - timedelta(seconds=RUN_TIMEOUT_SECONDS)

    # Same defaults as _mark_run_failed: missing values count as 0 retries
    # used and MAX_RETRIES_DEFAULT allowed.
    next_retry_count = func.coalesce(AutomationRun.retry_count, 0) + 1
    retry_limit = func.coalesce(AutomationRun.max_retries, MAX_RETRIES_DEFAULT)

    stale_conditions = (
        AutomationRun.status == 'RUNNING',
        AutomationRun.heartbeat_at < stale_threshold,
        AutomationRun.heartbeat_at >= timeout_threshold,
    )

    # 1. Stale runs that have used up their retries. Runs first, so these
    #    rows no longer match the recovery update below.
    exhausted_result = await session.execute(
        update(AutomationRun)
        .where(*stale_conditions, next_retry_count >= retry_limit)
        .values(
            status='DEAD_LETTER',
            retry_count=next_retry_count,
            error_detail='Run went stale and exhausted its retries',
            completed_at=now,
        )
        .returning(AutomationRun.id)
    )
    exhausted_rows = exhausted_result.fetchall()

    # 2. Recover stale runs (heartbeat expired) that may still be retried
    result = await session.execute(
        update(AutomationRun)
        .where(*stale_conditions, next_retry_count < retry_limit)
        .values(
            status='PENDING',
            claimed_by=None,
            retry_count=next_retry_count,
            next_retry_at=now + timedelta(seconds=30),
        )
        .returning(AutomationRun.id)
    )
    recovered_rows = result.fetchall()

    # 3. Mark truly timed-out runs as DEAD_LETTER
    timeout_result = await session.execute(
        update(AutomationRun)
        .where(
            AutomationRun.status == 'RUNNING',
            AutomationRun.heartbeat_at < timeout_threshold,
        )
        .values(
            status='DEAD_LETTER',
            error_detail='Run exceeded timeout',
            completed_at=now,
        )
        .returning(AutomationRun.id)
    )
    timed_out_rows = timeout_result.fetchall()

    await session.commit()

    exhausted_count = len(exhausted_rows)
    recovered_count = len(recovered_rows)
    timed_out_count = len(timed_out_rows)

    if recovered_count:
        logger.warning('Recovered %d stale automation runs', recovered_count)
    if exhausted_count:
        logger.warning(
            'Marked %d stale automation runs as DEAD_LETTER (retries exhausted)',
            exhausted_count,
        )
    if timed_out_count:
        logger.warning(
            'Marked %d automation runs as DEAD_LETTER (timeout)', timed_out_count
        )

    return recovered_count + exhausted_count + timed_out_count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main executor loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def executor_main(session_factory: object | None = None) -> None:
    """Main executor loop.

    Until shutdown is requested, each iteration processes new events, tries
    to claim one run, and recovers stale runs — each phase in its own session
    — then waits ``POLL_INTERVAL_SECONDS`` (or less if shutdown is requested).
    An exception in an iteration is logged and the loop continues.  On exit
    the function waits for in-flight run tasks and always closes the API
    client, even if that wait is cancelled or fails.

    Args:
        session_factory: Async context manager that yields AsyncSession instances.
            If None, uses the default ``a_session_maker`` from database module.
    """
    if session_factory is None:
        from storage.database import a_session_maker

        session_factory = a_session_maker

    executor_id = f'executor-{socket.gethostname()}-{os.getpid()}'
    api_url = os.getenv('OPENHANDS_API_URL', 'http://openhands-service:3000')
    api_client = OpenHandsAPIClient(base_url=api_url)

    logger.info(
        'Automation executor %s starting (api_url=%s, poll=%ss, heartbeat=%ss)',
        executor_id,
        api_url,
        POLL_INTERVAL_SECONDS,
        HEARTBEAT_INTERVAL_SECONDS,
    )

    try:
        while should_continue():
            try:
                async with session_factory() as session:
                    await process_new_events(session)

                async with session_factory() as session:
                    await claim_and_execute_runs(
                        session, executor_id, api_client, session_factory
                    )

                async with session_factory() as session:
                    await recover_stale_runs(session)

            except Exception:
                logger.exception('Error in executor main loop iteration')

            # Wait for next poll interval (or early wakeup on shutdown)
            try:
                await asyncio.wait_for(
                    get_shutdown_event().wait(),
                    timeout=POLL_INTERVAL_SECONDS,
                )
            except asyncio.TimeoutError:
                pass  # Normal — poll interval elapsed

    finally:
        try:
            if _pending_tasks:
                logger.info(
                    'Waiting for %d running tasks to complete...', len(_pending_tasks)
                )
                await asyncio.gather(*_pending_tasks, return_exceptions=True)
        finally:
            # Close the HTTP client even if the drain above was cancelled.
            await api_client.close()
            logger.info('Automation executor %s shut down', executor_id)
|
||||
93
enterprise/services/openhands_api_client.py
Normal file
93
enterprise/services/openhands_api_client.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""HTTP client for the main OpenHands V1 API (internal cluster calls).
|
||||
|
||||
Used by the automation executor to create and monitor conversations
|
||||
in the main OpenHands server.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger('saas.automation.api_client')
|
||||
|
||||
|
||||
def _raise_with_body(resp: httpx.Response) -> None:
    """Raise for a non-success response, appending the body to the message.

    Delegates to ``resp.raise_for_status()``; when that raises, the error is
    re-raised as a new ``httpx.HTTPStatusError`` whose message also contains
    up to the first 500 characters of the response text.
    """
    try:
        resp.raise_for_status()
    except httpx.HTTPStatusError as e:
        body_text = resp.text
        error_body = body_text[:500] if body_text else 'no response body'
        message = f'{e.args[0]} — Response: {error_body}'
        raise httpx.HTTPStatusError(
            message,
            request=e.request,
            response=e.response,
        ) from e
|
||||
|
||||
|
||||
class OpenHandsAPIClient:
    """Thin async wrapper around the OpenHands V1 conversation endpoints."""

    def __init__(self, base_url: str = 'http://openhands-service:3000'):
        # Trailing slashes are stripped so the stored base URL is normalised.
        self.base_url = base_url.rstrip('/')
        self.client = httpx.AsyncClient(base_url=self.base_url, timeout=30.0)

    async def start_conversation(
        self,
        api_key: str,
        automation_file: bytes,
        title: str,
        event_payload: dict | None = None,
    ) -> dict:
        """Create a conversation that runs an automation script.

        Args:
            api_key: User's API key, sent as a Bearer token.
            automation_file: Raw bytes of the .py automation script; sent
                base64-encoded.
            title: Display title for the conversation.
            event_payload: Optional trigger event data forwarded in the body.

        Returns:
            Parsed JSON response containing conversation details.

        Raises:
            httpx.HTTPStatusError: If the API returns a non-2xx status.
        """
        encoded_script = base64.b64encode(automation_file).decode()
        request_body = {
            'automation_file': encoded_script,
            'trigger': 'automation',
            'title': title,
            'event_payload': event_payload,
        }
        auth_headers = {'Authorization': f'Bearer {api_key}'}

        resp = await self.client.post(
            '/api/v1/app-conversations',
            json=request_body,
            headers=auth_headers,
        )
        _raise_with_body(resp)
        return resp.json()

    async def get_conversation(self, api_key: str, conversation_id: str) -> dict | None:
        """Fetch a single conversation by id.

        Args:
            api_key: User's API key, sent as a Bearer token.
            conversation_id: The conversation ID to look up.

        Returns:
            The first element of the JSON response, or None when the
            response is empty.

        Raises:
            httpx.HTTPStatusError: If the API returns a non-2xx status.
        """
        auth_headers = {'Authorization': f'Bearer {api_key}'}
        query = {'ids': [conversation_id]}

        resp = await self.client.get(
            '/api/v1/app-conversations',
            params=query,
            headers=auth_headers,
        )
        _raise_with_body(resp)

        conversations = resp.json()
        if not conversations:
            return None
        return conversations[0]

    async def close(self) -> None:
        """Release the underlying ``httpx.AsyncClient``."""
        await self.client.aclose()
|
||||
77
enterprise/storage/automation.py
Normal file
77
enterprise/storage/automation.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""SQLAlchemy models for automations and automation runs.
|
||||
|
||||
Stub for Task 1 (Data Foundation). These models will be replaced when Task 1
|
||||
is merged into automations-phase1.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.types import JSON
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class Automation(Base):
    """ORM model for the ``automations`` table: one row per automation."""

    __tablename__ = 'automations'

    # Identity and ownership.
    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False, index=True)
    org_id = Column(String, nullable=True, index=True)

    name = Column(String, nullable=False)
    # Defaults to true at the database level.
    enabled = Column(Boolean, nullable=False, server_default=text('true'))

    # JSON configuration, trigger kind, and the storage key of the script.
    config = Column(JSON, nullable=False)
    trigger_type = Column(String, nullable=False)
    file_store_key = Column(String, nullable=False)

    last_triggered_at = Column(DateTime(timezone=True), nullable=True)
    created_at = Column(
        DateTime(timezone=True),
        nullable=False,
        server_default=text('CURRENT_TIMESTAMP'),
    )
    # NOTE(review): only a server default is declared here — no onupdate —
    # so this column is not refreshed automatically on UPDATE by this model.
    updated_at = Column(
        DateTime(timezone=True),
        nullable=False,
        server_default=text('CURRENT_TIMESTAMP'),
    )

    # Runs belonging to this automation (paired with AutomationRun.automation).
    runs = relationship('AutomationRun', back_populates='automation')
|
||||
|
||||
|
||||
class AutomationRun(Base):
    """ORM model for the ``automation_runs`` table: one execution attempt."""

    __tablename__ = 'automation_runs'

    id = Column(String, primary_key=True)
    # Deleting the parent automation cascades to its runs.
    automation_id = Column(
        String, ForeignKey('automations.id', ondelete='CASCADE'), nullable=False
    )
    # Optional link to the inbox event that produced this run.
    event_id = Column(Integer, ForeignKey('automation_events.id'), nullable=True)
    conversation_id = Column(String, nullable=True)
    # Defaults to 'PENDING' at the database level.
    status = Column(String, nullable=False, server_default=text("'PENDING'"))
    # Claim and heartbeat bookkeeping.
    claimed_by = Column(String, nullable=True)
    claimed_at = Column(DateTime(timezone=True), nullable=True)
    heartbeat_at = Column(DateTime(timezone=True), nullable=True)
    # Retry bookkeeping; database defaults are 0 used and 3 allowed.
    retry_count = Column(Integer, nullable=False, server_default=text('0'))
    max_retries = Column(Integer, nullable=False, server_default=text('3'))
    next_retry_at = Column(DateTime(timezone=True), nullable=True)
    event_payload = Column(JSON, nullable=True)
    error_detail = Column(Text, nullable=True)
    started_at = Column(DateTime(timezone=True), nullable=True)
    completed_at = Column(DateTime(timezone=True), nullable=True)
    created_at = Column(
        DateTime(timezone=True),
        nullable=False,
        server_default=text('CURRENT_TIMESTAMP'),
    )

    # Parent automation (paired with Automation.runs).
    automation = relationship('Automation', back_populates='runs')
|
||||
27
enterprise/storage/automation_event.py
Normal file
27
enterprise/storage/automation_event.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""SQLAlchemy model for automation events (the inbox).
|
||||
|
||||
Stub for Task 1 (Data Foundation). This model will be replaced when Task 1
|
||||
is merged into automations-phase1.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Integer, String, Text, text
|
||||
from sqlalchemy.types import JSON
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class AutomationEvent(Base):
    """ORM model for the ``automation_events`` table (the event inbox)."""

    __tablename__ = 'automation_events'

    id = Column(Integer, primary_key=True, autoincrement=True)
    source_type = Column(String, nullable=False)
    payload = Column(JSON, nullable=False)
    # Attribute is ``metadata_`` but the database column is named 'metadata'.
    metadata_ = Column('metadata', JSON, nullable=True)
    # Unique per event.
    dedup_key = Column(String, nullable=False, unique=True)
    # Defaults to 'NEW' at the database level.
    status = Column(String, nullable=False, server_default=text("'NEW'"))
    error_detail = Column(Text, nullable=True)
    created_at = Column(
        DateTime(timezone=True),
        nullable=False,
        server_default=text('CURRENT_TIMESTAMP'),
    )
    processed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
0
enterprise/tests/unit/services/__init__.py
Normal file
0
enterprise/tests/unit/services/__init__.py
Normal file
68
enterprise/tests/unit/services/conftest.py
Normal file
68
enterprise/tests/unit/services/conftest.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Shared fixtures for services tests.
|
||||
|
||||
Note: We pre-load ``storage`` as a namespace package to avoid the heavy
|
||||
``storage/__init__.py`` that imports the entire enterprise model graph.
|
||||
This must happen *before* any ``from storage.…`` import.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Prevent storage/__init__.py from loading the full model graph.
# We only need the lightweight automation models for these tests.
# A bare module object with ``__path__`` set is registered under the name
# 'storage', so later ``from storage.x import ...`` statements resolve
# submodules from that directory without running the package's __init__.
if 'storage' not in sys.modules:
    import pathlib

    # Three directories above this file's directory, then ``storage``.
    _storage_dir = str(pathlib.Path(__file__).resolve().parents[3] / 'storage')
    _mod = types.ModuleType('storage')
    _mod.__path__ = [_storage_dir]
    sys.modules['storage'] = _mod
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.automation import Automation, AutomationRun # noqa: F401
|
||||
from storage.automation_event import AutomationEvent # noqa: F401
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@pytest.fixture
async def async_engine():
    """Yield an in-memory async SQLite engine with all tables created.

    The engine is disposed after the test finishes.
    """
    sqlite_url = 'sqlite+aiosqlite:///:memory:'
    db_engine = create_async_engine(
        sqlite_url,
        poolclass=StaticPool,
        connect_args={'check_same_thread': False},
        echo=False,
    )

    # Create the schema for every model registered on Base.
    async with db_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield db_engine
    await db_engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
async def async_session_factory(async_engine):
    """Return a callable producing async-context-managed sessions.

    Sessions are bound to ``async_engine`` and do not expire objects on
    commit.
    """
    session_maker = async_sessionmaker(
        bind=async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )

    @contextlib.asynccontextmanager
    async def _open_session():
        async with session_maker() as db_session:
            yield db_session

    return _open_session
|
||||
|
||||
|
||||
@pytest.fixture
async def async_session(async_session_factory):
    """Yield one session opened from ``async_session_factory``."""
    session_ctx = async_session_factory()
    async with session_ctx as db_session:
        yield db_session
|
||||
624
enterprise/tests/unit/services/test_automation_executor.py
Normal file
624
enterprise/tests/unit/services/test_automation_executor.py
Normal file
@@ -0,0 +1,624 @@
|
||||
"""Tests for the automation executor.
|
||||
|
||||
Uses real SQLite database operations for event processing, run claiming,
|
||||
and stale run recovery. HTTP calls to the V1 API are mocked.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from services.automation_executor import (
|
||||
_mark_run_failed,
|
||||
claim_and_execute_runs,
|
||||
find_matching_automations,
|
||||
is_terminal,
|
||||
process_new_events,
|
||||
recover_stale_runs,
|
||||
utc_now,
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from storage.automation import Automation, AutomationRun
|
||||
from storage.automation_event import AutomationEvent
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_automation(
    automation_id: str = 'auto-1',
    user_id: str = 'user-1',
    enabled: bool = True,
    trigger_type: str = 'cron',
    name: str = 'Test Automation',
) -> Automation:
    """Build an unsaved ``Automation`` with fixed org, cron config and script key.

    The ``file_store_key`` is derived from ``automation_id``; every other
    field not exposed as a parameter is a constant.
    """
    cron_config = {'triggers': {'cron': {'schedule': '0 9 * * 5'}}}
    script_key = f'automations/{automation_id}/script.py'
    return Automation(
        id=automation_id,
        user_id=user_id,
        org_id='org-1',
        name=name,
        enabled=enabled,
        config=cron_config,
        trigger_type=trigger_type,
        file_store_key=script_key,
    )
|
||||
|
||||
|
||||
def make_event(
    source_type: str = 'cron',
    payload: dict | None = None,
    status: str = 'NEW',
    dedup_key: str | None = None,
) -> AutomationEvent:
    """Build an unsaved ``AutomationEvent`` stamped with the current time.

    Args:
        source_type: Value stored as the event's ``source_type``.
        payload: Event payload. When ``None`` a payload pointing at
            automation ``'auto-1'`` is used; an explicitly passed empty
            dict is kept as-is.
        status: Initial event status.
        dedup_key: Dedup key. When ``None`` a random ``dedup-<8 hex>`` key
            is generated; an explicitly passed empty string is kept as-is.

    Returns:
        The new, not yet persisted event.
    """
    # Compare against None rather than truthiness so callers can build
    # events with a deliberately empty payload or dedup key.
    if payload is None:
        payload = {'automation_id': 'auto-1'}
    if dedup_key is None:
        dedup_key = f'dedup-{uuid4().hex[:8]}'

    return AutomationEvent(
        source_type=source_type,
        payload=payload,
        dedup_key=dedup_key,
        status=status,
        created_at=utc_now(),
    )
|
||||
|
||||
|
||||
def make_run(
    run_id: str | None = None,
    automation_id: str = 'auto-1',
    status: str = 'PENDING',
    claimed_by: str | None = None,
    heartbeat_at: datetime | None = None,
    retry_count: int = 0,
    max_retries: int = 3,
    next_retry_at: datetime | None = None,
) -> AutomationRun:
    """Build an unsaved ``AutomationRun`` stamped with the current time.

    A falsy ``run_id`` is replaced by a random hex id. The run's
    ``event_payload`` always carries just the given ``automation_id``.
    """
    identifier = run_id if run_id else uuid4().hex
    trigger_payload = {'automation_id': automation_id}
    return AutomationRun(
        id=identifier,
        automation_id=automation_id,
        status=status,
        claimed_by=claimed_by,
        heartbeat_at=heartbeat_at,
        retry_count=retry_count,
        max_retries=max_retries,
        next_retry_at=next_retry_at,
        event_payload=trigger_payload,
        created_at=utc_now(),
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_automations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_matching_automations_cron_event(async_session):
    """A cron event resolves to the automation named in its payload."""
    stored_automation = make_automation()
    async_session.add(stored_automation)
    await async_session.commit()

    cron_event = make_event(
        source_type='cron',
        payload={'automation_id': 'auto-1'},
    )
    async_session.add(cron_event)
    await async_session.commit()

    matches = await find_matching_automations(async_session, cron_event)

    # Exactly one match, and it is the seeded automation.
    assert len(matches) == 1
    first = matches[0]
    assert first.id == 'auto-1'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_matching_automations_manual_event(async_session):
    """A manual event is matched through the payload's automation_id too."""
    stored_automation = make_automation()
    async_session.add(stored_automation)
    await async_session.commit()

    manual_event = make_event(
        source_type='manual',
        payload={'automation_id': 'auto-1'},
    )
    async_session.add(manual_event)
    await async_session.commit()

    matches = await find_matching_automations(async_session, manual_event)

    # Exactly one match, and it is the seeded automation.
    assert len(matches) == 1
    first = matches[0]
    assert first.id == 'auto-1'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_matching_automations_disabled_automation(async_session):
    """An automation stored with enabled=False is never returned as a match."""
    disabled = make_automation(enabled=False)
    async_session.add(disabled)
    await async_session.commit()

    trigger = make_event(payload={'automation_id': 'auto-1'})
    async_session.add(trigger)
    await async_session.commit()

    matches = await find_matching_automations(async_session, trigger)

    assert len(matches) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_matching_automations_missing_automation_id(async_session):
    """A payload lacking the automation_id key yields no matches."""
    trigger = make_event(payload={'something_else': 'value'})
    async_session.add(trigger)
    await async_session.commit()

    matches = await find_matching_automations(async_session, trigger)

    assert len(matches) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_matching_automations_nonexistent_automation(async_session):
    """An automation_id with no stored automation yields no matches."""
    trigger = make_event(payload={'automation_id': 'nonexistent'})
    async_session.add(trigger)
    await async_session.commit()

    matches = await find_matching_automations(async_session, trigger)

    assert len(matches) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_matching_automations_unknown_source_type(async_session):
    """An event whose source_type is 'unknown' yields no matches."""
    trigger = make_event(
        source_type='unknown',
        payload={'automation_id': 'auto-1'},
    )
    async_session.add(trigger)
    await async_session.commit()

    matches = await find_matching_automations(async_session, trigger)

    assert len(matches) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_new_events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_new_events_creates_runs(async_session):
    """A matching NEW event becomes PROCESSED and spawns one PENDING run."""
    target = make_automation()
    inbox_event = make_event(payload={'automation_id': 'auto-1'})
    async_session.add_all([target, inbox_event])
    await async_session.commit()

    handled = await process_new_events(async_session)

    assert handled == 1

    # The event is marked as handled with a timestamp.
    await async_session.refresh(inbox_event)
    assert inbox_event.status == 'PROCESSED'
    assert inbox_event.processed_at is not None

    # Exactly one run exists and it mirrors the event.
    query = select(AutomationRun)
    rows = await async_session.execute(query)
    created = rows.scalars().all()
    assert len(created) == 1
    new_run = created[0]
    assert new_run.automation_id == 'auto-1'
    assert new_run.status == 'PENDING'
    assert new_run.event_payload == {'automation_id': 'auto-1'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_new_events_no_match(async_session):
    """An event with no matching automation ends up NO_MATCH with no run."""
    orphan_event = make_event(payload={'automation_id': 'nonexistent'})
    async_session.add(orphan_event)
    await async_session.commit()

    handled = await process_new_events(async_session)

    assert handled == 1

    await async_session.refresh(orphan_event)
    assert orphan_event.status == 'NO_MATCH'
    assert orphan_event.processed_at is not None

    # The run table stays empty.
    query = select(AutomationRun)
    rows = await async_session.execute(query)
    created = rows.scalars().all()
    assert len(created) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_new_events_skips_processed(async_session):
    """An event already in PROCESSED status is not counted again."""
    finished_event = make_event(status='PROCESSED')
    async_session.add(finished_event)
    await async_session.commit()

    handled = await process_new_events(async_session)

    assert handled == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_new_events_multiple_events(async_session):
    """Three NEW events are handled in one call: two match, one does not."""
    first_auto = make_automation(automation_id='auto-1')
    second_auto = make_automation(automation_id='auto-2', name='Auto 2')
    first_event = make_event(
        payload={'automation_id': 'auto-1'}, dedup_key='dedup-1'
    )
    second_event = make_event(
        payload={'automation_id': 'auto-2'}, dedup_key='dedup-2'
    )
    orphan_event = make_event(
        payload={'automation_id': 'nonexistent'}, dedup_key='dedup-3'
    )
    async_session.add_all(
        [first_auto, second_auto, first_event, second_event, orphan_event]
    )
    await async_session.commit()

    handled = await process_new_events(async_session)

    assert handled == 3

    # Only the two events with a real automation produce runs.
    query = select(AutomationRun)
    rows = await async_session.execute(query)
    created = rows.scalars().all()
    assert len(created) == 2

    await async_session.refresh(first_event)
    await async_session.refresh(second_event)
    await async_session.refresh(orphan_event)
    assert first_event.status == 'PROCESSED'
    assert second_event.status == 'PROCESSED'
    assert orphan_event.status == 'NO_MATCH'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# claim_and_execute_runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_claim_and_execute_runs_claims_pending(
    async_session, async_session_factory
):
    """A PENDING run is claimed and stamped as RUNNING for this executor."""
    owner = make_automation()
    pending = make_run(run_id='run-1')
    async_session.add_all([owner, pending])
    await async_session.commit()

    fake_api = AsyncMock()

    # execute_run is stubbed out so only the claiming step is exercised.
    with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
        was_claimed = await claim_and_execute_runs(
            async_session, 'executor-test-1', fake_api, async_session_factory
        )

    assert was_claimed is True

    await async_session.refresh(pending)
    assert pending.status == 'RUNNING'
    assert pending.claimed_by == 'executor-test-1'
    assert pending.claimed_at is not None
    assert pending.heartbeat_at is not None
    assert pending.started_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_claim_and_execute_runs_no_pending(async_session, async_session_factory):
    """With an empty run table the claim call reports False."""
    fake_api = AsyncMock()

    was_claimed = await claim_and_execute_runs(
        async_session, 'executor-test-1', fake_api, async_session_factory
    )

    assert was_claimed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_claim_and_execute_runs_respects_next_retry_at(
    async_session, async_session_factory
):
    """A run whose next_retry_at lies an hour ahead is left unclaimed."""
    owner = make_automation()
    retry_time = utc_now() + timedelta(hours=1)
    deferred = make_run(run_id='run-retry', next_retry_at=retry_time)
    async_session.add_all([owner, deferred])
    await async_session.commit()

    fake_api = AsyncMock()

    was_claimed = await claim_and_execute_runs(
        async_session, 'executor-test-1', fake_api, async_session_factory
    )

    assert was_claimed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_claim_and_execute_runs_past_retry_at(
    async_session, async_session_factory
):
    """A run whose next_retry_at is five minutes in the past gets claimed."""
    owner = make_automation()
    retry_time = utc_now() - timedelta(minutes=5)
    due = make_run(run_id='run-retry-past', next_retry_at=retry_time)
    async_session.add_all([owner, due])
    await async_session.commit()

    fake_api = AsyncMock()

    # execute_run is stubbed out so only the claiming step is exercised.
    with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
        was_claimed = await claim_and_execute_runs(
            async_session, 'executor-test-1', fake_api, async_session_factory
        )

    assert was_claimed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_claim_skips_running_runs(async_session, async_session_factory):
    """A run already RUNNING under another executor is not claimed."""
    owner = make_automation()
    busy = make_run(
        run_id='run-running',
        status='RUNNING',
        claimed_by='other-executor',
    )
    async_session.add_all([owner, busy])
    await async_session.commit()

    fake_api = AsyncMock()

    was_claimed = await claim_and_execute_runs(
        async_session, 'executor-test-1', fake_api, async_session_factory
    )

    assert was_claimed is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# recover_stale_runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_recover_stale_runs_recovers_stale(async_session):
    """A RUNNING run with a ten-minute-old heartbeat goes back to PENDING."""
    owner = make_automation()
    last_beat = utc_now() - timedelta(minutes=10)
    abandoned = make_run(
        run_id='stale-1',
        status='RUNNING',
        claimed_by='crashed-executor',
        heartbeat_at=last_beat,
        retry_count=0,
    )
    async_session.add_all([owner, abandoned])
    await async_session.commit()

    recovered = await recover_stale_runs(async_session)

    assert recovered >= 1

    # The run is released, counted as one retry and scheduled for later.
    await async_session.refresh(abandoned)
    assert abandoned.status == 'PENDING'
    assert abandoned.claimed_by is None
    assert abandoned.retry_count == 1
    assert abandoned.next_retry_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_recover_stale_runs_ignores_fresh(async_session):
    """A RUNNING run with a 30-second-old heartbeat is left untouched."""
    owner = make_automation()
    last_beat = utc_now() - timedelta(seconds=30)
    healthy = make_run(
        run_id='fresh-1',
        status='RUNNING',
        claimed_by='active-executor',
        heartbeat_at=last_beat,
    )
    async_session.add_all([owner, healthy])
    await async_session.commit()

    recovered = await recover_stale_runs(async_session)

    assert recovered == 0

    await async_session.refresh(healthy)
    assert healthy.status == 'RUNNING'
    assert healthy.claimed_by == 'active-executor'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_recover_stale_runs_ignores_pending(async_session):
    """Recovery leaves a PENDING run in PENDING and reports zero."""
    owner = make_automation()
    waiting = make_run(run_id='pending-1', status='PENDING')
    async_session.add_all([owner, waiting])
    await async_session.commit()

    recovered = await recover_stale_runs(async_session)

    assert recovered == 0

    await async_session.refresh(waiting)
    assert waiting.status == 'PENDING'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_recover_stale_runs_increments_retry_count(async_session):
    """Recovering a stale run bumps retry_count from 2 to 3."""
    owner = make_automation()
    last_beat = utc_now() - timedelta(minutes=10)
    abandoned = make_run(
        run_id='stale-retry',
        status='RUNNING',
        claimed_by='old-executor',
        heartbeat_at=last_beat,
        retry_count=2,
    )
    async_session.add_all([owner, abandoned])
    await async_session.commit()

    await recover_stale_runs(async_session)

    await async_session.refresh(abandoned)
    assert abandoned.retry_count == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _mark_run_failed (error handling)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mark_run_failed_retries(async_session_factory):
    """A failure with retries remaining puts the run back to PENDING."""
    # Seed the automation and a run that has not been retried yet.
    async with async_session_factory() as seed_session:
        owner = make_automation()
        fresh_run = make_run(run_id='fail-retry', retry_count=0, max_retries=3)
        seed_session.add_all([owner, fresh_run])
        await seed_session.commit()

    # Load the run and report the failure.
    async with async_session_factory() as load_session:
        loaded = await load_session.get(AutomationRun, 'fail-retry')
        await _mark_run_failed(loaded, 'API error', async_session_factory)

    # Re-read from a new session to check what was persisted.
    async with async_session_factory() as check_session:
        stored = await check_session.get(AutomationRun, 'fail-retry')
        assert stored.status == 'PENDING'
        assert stored.retry_count == 1
        assert stored.error_detail == 'API error'
        assert stored.next_retry_at is not None
        assert stored.claimed_by is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mark_run_failed_dead_letter(async_session_factory):
    """A failure that reaches max_retries parks the run in DEAD_LETTER."""
    # Seed the automation and a run one retry short of its limit.
    async with async_session_factory() as seed_session:
        owner = make_automation()
        worn_run = make_run(run_id='fail-dead', retry_count=2, max_retries=3)
        seed_session.add_all([owner, worn_run])
        await seed_session.commit()

    # Load the run and report the failure.
    async with async_session_factory() as load_session:
        loaded = await load_session.get(AutomationRun, 'fail-dead')
        await _mark_run_failed(loaded, 'Final failure', async_session_factory)

    # Re-read from a new session to check what was persisted.
    async with async_session_factory() as check_session:
        stored = await check_session.get(AutomationRun, 'fail-dead')
        assert stored.status == 'DEAD_LETTER'
        assert stored.retry_count == 3
        assert stored.error_detail == 'Final failure'
        assert stored.completed_at is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_terminal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_terminal_stopped():
    """Status 'STOPPED' is reported as terminal."""
    state = {'status': 'STOPPED'}
    assert is_terminal(state) is True
|
||||
|
||||
|
||||
def test_is_terminal_error():
    """Status 'ERROR' is reported as terminal."""
    state = {'status': 'ERROR'}
    assert is_terminal(state) is True
|
||||
|
||||
|
||||
def test_is_terminal_completed():
    """Status 'COMPLETED' is reported as terminal."""
    state = {'status': 'COMPLETED'}
    assert is_terminal(state) is True
|
||||
|
||||
|
||||
def test_is_terminal_cancelled():
    """Status 'CANCELLED' is reported as terminal."""
    state = {'status': 'CANCELLED'}
    assert is_terminal(state) is True
|
||||
|
||||
|
||||
def test_is_terminal_running():
    """Status 'RUNNING' is not terminal."""
    state = {'status': 'RUNNING'}
    assert is_terminal(state) is False
|
||||
|
||||
|
||||
def test_is_terminal_empty():
    """A dict without a status key is not terminal."""
    state = {}
    assert is_terminal(state) is False
|
||||
|
||||
|
||||
def test_is_terminal_case_insensitive():
    """Lower-case and mixed-case status values are still recognised."""
    lower_state = {'status': 'stopped'}
    mixed_state = {'status': 'Completed'}
    assert is_terminal(lower_state) is True
    assert is_terminal(mixed_state) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_automations — None payload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_matching_automations_none_payload(async_session):
    """An event whose payload is None produces an empty match list."""
    broken_event = make_event(source_type='cron')
    # make_event substitutes a default payload, so None is set afterwards.
    broken_event.payload = None
    async_session.add(broken_event)
    await async_session.commit()

    matches = await find_matching_automations(async_session, broken_event)

    assert matches == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: event → run creation → claim
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_integration_event_to_run_to_claim(
    async_session_factory,
):
    """Walk one event through inbox processing and run claiming.

    Seeds an automation plus a NEW event, runs process_new_events, then
    claim_and_execute_runs, checking the stored state after each stage.
    The database is real SQLite; execute_run and the API client are mocked.
    """
    # Stage 1: seed an automation and a NEW event.
    async with async_session_factory() as seed_session:
        owner = make_automation(automation_id='integ-auto')
        inbox_event = make_event(
            source_type='cron',
            payload={'automation_id': 'integ-auto'},
            dedup_key='integ-dedup',
        )
        seed_session.add_all([owner, inbox_event])
        await seed_session.commit()
        event_id = inbox_event.id

    # Stage 2: process the inbox; the event should match and create a run.
    async with async_session_factory() as inbox_session:
        handled = await process_new_events(inbox_session)

    assert handled == 1

    # The event is PROCESSED and exactly one PENDING run exists.
    async with async_session_factory() as check_session:
        stored_event = await check_session.get(AutomationEvent, event_id)
        assert stored_event.status == 'PROCESSED'

        rows = await check_session.execute(select(AutomationRun))
        created = rows.scalars().all()
        assert len(created) == 1
        new_run = created[0]
        assert new_run.automation_id == 'integ-auto'
        assert new_run.status == 'PENDING'
        assert new_run.event_payload == {'automation_id': 'integ-auto'}

    # Stage 3: claim the run with execute_run stubbed out.
    fake_api = AsyncMock()

    with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
        async with async_session_factory() as claim_session:
            was_claimed = await claim_and_execute_runs(
                claim_session, 'executor-integ', fake_api, async_session_factory
            )

    assert was_claimed is True

    # Stage 4: the run is RUNNING and owned by this executor.
    async with async_session_factory() as final_session:
        rows = await final_session.execute(select(AutomationRun))
        remaining = rows.scalars().all()
        assert len(remaining) == 1
        claimed_run = remaining[0]
        assert claimed_run.status == 'RUNNING'
        assert claimed_run.claimed_by == 'executor-integ'
        assert claimed_run.started_at is not None
        assert claimed_run.heartbeat_at is not None
|
||||
185
enterprise/tests/unit/services/test_openhands_api_client.py
Normal file
185
enterprise/tests/unit/services/test_openhands_api_client.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""Tests for OpenHandsAPIClient with mocked HTTP responses."""
|
||||
|
||||
import base64
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from services.openhands_api_client import OpenHandsAPIClient
|
||||
|
||||
|
||||
@pytest.fixture
def api_client():
    """Provide an ``OpenHandsAPIClient`` aimed at a fake test server URL.

    No teardown is performed here; tests that need the client closed call
    ``close()`` themselves.
    """
    test_client = OpenHandsAPIClient(base_url='http://test-server:3000')
    yield test_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# start_conversation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_conversation_sends_correct_request(api_client, respx_mock):
    """start_conversation POSTs the expected JSON body and bearer token."""
    import json

    script_bytes = b'print("hello")'
    encoded_script = base64.b64encode(script_bytes).decode()

    canned_response = httpx.Response(
        200,
        json={
            'app_conversation_id': 'conv-123',
            'status': 'RUNNING',
        },
    )
    post_route = respx_mock.post(
        'http://test-server:3000/api/v1/app-conversations'
    ).mock(return_value=canned_response)

    reply = await api_client.start_conversation(
        api_key='sk-oh-test123',
        automation_file=script_bytes,
        title='Test Automation',
        event_payload={'automation_id': 'auto-1'},
    )

    assert post_route.called
    sent = post_route.calls[0].request

    assert sent.headers['Authorization'] == 'Bearer sk-oh-test123'

    # The script travels base64-encoded next to the trigger metadata.
    sent_body = json.loads(sent.content)
    assert sent_body['automation_file'] == encoded_script
    assert sent_body['trigger'] == 'automation'
    assert sent_body['title'] == 'Test Automation'
    assert sent_body['event_payload'] == {'automation_id': 'auto-1'}

    assert reply == {'app_conversation_id': 'conv-123', 'status': 'RUNNING'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_conversation_without_event_payload(api_client, respx_mock):
    """Passing event_payload=None still returns the server's JSON reply."""
    canned_response = httpx.Response(200, json={'app_conversation_id': 'conv-456'})
    respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
        return_value=canned_response
    )

    reply = await api_client.start_conversation(
        api_key='sk-oh-test',
        automation_file=b'code',
        title='Test',
        event_payload=None,
    )

    assert reply['app_conversation_id'] == 'conv-456'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_conversation_http_error(api_client, respx_mock):
    """A 500 reply surfaces as httpx.HTTPStatusError carrying that status."""
    canned_response = httpx.Response(500, json={'error': 'Internal Server Error'})
    respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
        return_value=canned_response
    )

    with pytest.raises(httpx.HTTPStatusError) as caught:
        await api_client.start_conversation(
            api_key='sk-oh-test',
            automation_file=b'code',
            title='Test',
        )

    assert caught.value.response.status_code == 500
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_start_conversation_auth_error(api_client, respx_mock):
    """A 401 reply surfaces as httpx.HTTPStatusError carrying that status."""
    canned_response = httpx.Response(401, json={'error': 'Unauthorized'})
    respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
        return_value=canned_response
    )

    with pytest.raises(httpx.HTTPStatusError) as caught:
        await api_client.start_conversation(
            api_key='bad-key',
            automation_file=b'code',
            title='Test',
        )

    assert caught.value.response.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_conversation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_conversation_returns_data(api_client, respx_mock):
    """get_conversation hands back the first item of the returned list."""
    listing = [
        {
            'conversation_id': 'conv-123',
            'status': 'RUNNING',
            'title': 'My Automation',
        }
    ]
    respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
        return_value=httpx.Response(200, json=listing)
    )

    conversation = await api_client.get_conversation('sk-oh-test', 'conv-123')

    assert conversation is not None
    assert conversation['conversation_id'] == 'conv-123'
    assert conversation['status'] == 'RUNNING'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_conversation_returns_none_when_empty(api_client, respx_mock):
    """An empty list from the API makes get_conversation return None."""
    empty_listing = httpx.Response(200, json=[])
    respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
        return_value=empty_listing
    )

    conversation = await api_client.get_conversation('sk-oh-test', 'nonexistent')

    assert conversation is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_conversation_sends_auth_header(api_client, respx_mock):
    """get_conversation forwards the API key as a bearer token."""
    empty_listing = httpx.Response(200, json=[])
    get_route = respx_mock.get(
        'http://test-server:3000/api/v1/app-conversations'
    ).mock(return_value=empty_listing)

    await api_client.get_conversation('sk-oh-mykey', 'conv-1')

    assert get_route.called
    sent = get_route.calls[0].request
    assert sent.headers['Authorization'] == 'Bearer sk-oh-mykey'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_conversation_http_error(api_client, respx_mock):
    """A 503 reply makes get_conversation raise httpx.HTTPStatusError."""
    unavailable = httpx.Response(503, text='Service Unavailable')
    respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
        return_value=unavailable
    )

    with pytest.raises(httpx.HTTPStatusError):
        await api_client.get_conversation('sk-oh-test', 'conv-1')
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# close
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_close(api_client):
    """Awaiting close() on a fresh client completes without raising."""
    closing = api_client.close()
    await closing
|
||||
Reference in New Issue
Block a user