Compare commits

..

3 Commits

Author SHA1 Message Date
openhands
6ec03098ad fix: add integration tests, store prompt in config, remove AST extraction, extract pagination helper, simplify update
- Thread 1: Added integration tests using real SQLite database (aiosqlite)
  that exercise actual SQL queries for list, get, create, delete, pagination
- Thread 3: Store prompt in config JSON column so DB is source of truth,
  not the generated file
- Thread 4: Removed _extract_prompt_from_file (AST extraction) entirely
- Thread 5: Extracted _paginate() helper used by search_automations and
  list_automation_runs
- Thread 6: Simplified update endpoint - reads prompt from config instead
  of parsing file content

Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-11 18:01:25 +00:00
openhands
20e0ebacf0 fix: address PR review feedback for automation CRUD API
- Fix pagination cursor bug: next_page_id now points to rows[limit]
  (first item of next page) instead of rows[limit-1] in both
  search_automations and list_automation_runs
- Simplify update endpoint: use model_dump(exclude_unset=True) to
  extract changed fields and set intersection for regen detection
- Add user isolation security tests: verify user B cannot GET, PATCH,
  DELETE, or search user A's automations
- Move automation_router import to top of saas_server.py with other
  router imports
- Consistent error handling: capture file_store_key before session
  close in delete endpoint
- Safe isoformat() calls: guard against None timestamps in
  _automation_to_response and _run_to_response

Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-11 08:23:20 +00:00
openhands
3230813a95 [Automations Phase 1] Task 2: CRUD API
Implement REST API for creating, reading, updating, and deleting
automations. Simple mode only for Phase 1 (form input → generated file).

New files:
- enterprise/server/routes/automation_models.py: Pydantic request/response models
- enterprise/server/routes/automations.py: FastAPI router with 8 endpoints
- enterprise/storage/automation.py: SQLAlchemy models (stub for Task 1)
- enterprise/storage/automation_event.py: SQLAlchemy model (stub for Task 1)
- enterprise/services/: Config, file generator, event publisher stubs

Endpoints:
- POST /api/v1/automations — Create automation
- GET /api/v1/automations/search — List automations (paginated)
- GET /api/v1/automations/{id} — Get automation
- PATCH /api/v1/automations/{id} — Update automation
- DELETE /api/v1/automations/{id} — Delete automation and runs
- POST /api/v1/automations/{id}/run — Manual trigger
- GET /api/v1/automations/{id}/runs — List runs (paginated)
- GET /api/v1/automations/{id}/runs/{run_id} — Get run detail

Part of RFC: https://github.com/OpenHands/OpenHands/issues/13275

Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-10 18:45:48 +00:00
17 changed files with 1596 additions and 1653 deletions

View File

@@ -1,71 +0,0 @@
"""Entry point for the automation executor.
Usage: python -m run_automation_executor
This runs as a Kubernetes Deployment (long-running). It polls the automation_events
inbox, matches events to automations, claims and executes runs, and monitors
conversation completion.
Environment variables:
OPENHANDS_API_URL Base URL for the V1 API (default: http://openhands-service:3000)
MAX_CONCURRENT_RUNS Max concurrent runs per executor (default: 5)
RUN_TIMEOUT_SECONDS Max time for a single run (default: 7200)
POLL_INTERVAL_SECONDS Fallback poll interval (default: 30)
HEARTBEAT_INTERVAL_SECONDS Heartbeat update interval (default: 60)
"""
import asyncio
import logging
import signal
import sys
logger = logging.getLogger('saas.automation.executor')
def _setup_logging() -> None:
"""Configure logging, deferring to enterprise logger if available."""
try:
from server.logger import setup_all_loggers
setup_all_loggers()
except ImportError:
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s %(name)s %(levelname)s %(message)s',
stream=sys.stdout,
)
def _install_signal_handlers(loop: asyncio.AbstractEventLoop) -> None:
"""Install signal handlers for graceful shutdown."""
from services.automation_executor import request_shutdown
def _handle_signal(signum: int, _frame: object) -> None:
sig_name = signal.Signals(signum).name
logger.info('Received %s, initiating graceful shutdown...', sig_name)
request_shutdown()
for sig in (signal.SIGTERM, signal.SIGINT):
signal.signal(sig, _handle_signal)
async def main() -> None:
from services.automation_executor import executor_main
await executor_main()
if __name__ == '__main__':
_setup_logging()
loop = asyncio.new_event_loop()
_install_signal_handlers(loop)
logger.info('Starting automation executor')
try:
loop.run_until_complete(main())
except KeyboardInterrupt:
logger.info('Interrupted by user')
finally:
loop.close()
logger.info('Automation executor process exiting')

View File

@@ -27,6 +27,7 @@ from server.middleware import SetAuthCookieMiddleware # noqa: E402
from server.rate_limit import setup_rate_limit_handler # noqa: E402
from server.routes.api_keys import api_router as api_keys_router # noqa: E402
from server.routes.auth import api_router, oauth_router # noqa: E402
from server.routes.automations import automation_router # noqa: E402
from server.routes.billing import billing_router # noqa: E402
from server.routes.email import api_router as email_router # noqa: E402
from server.routes.event_webhook import event_webhook_router # noqa: E402
@@ -139,6 +140,8 @@ if BITBUCKET_DATA_CENTER_HOST:
base_app.include_router(bitbucket_dc_proxy_router)
base_app.include_router(email_router) # Add routes for email management
base_app.include_router(feedback_router) # Add routes for conversation feedback
base_app.include_router(automation_router) # Add routes for automation CRUD
base_app.include_router(
event_webhook_router
) # Add routes for Events in nested runtimes

View File

@@ -0,0 +1,59 @@
"""Pydantic request/response models for automation CRUD API."""
from pydantic import BaseModel, Field
class CreateAutomationRequest(BaseModel):
"""Simple mode (Phase 1): form input → generated file."""
name: str = Field(min_length=1, max_length=200)
schedule: str # 5-field cron expression
timezone: str = 'UTC'
prompt: str = Field(min_length=1)
repository: str | None = None # e.g., "owner/repo"
branch: str | None = None
class UpdateAutomationRequest(BaseModel):
name: str | None = None
schedule: str | None = None
timezone: str | None = None
prompt: str | None = None
repository: str | None = None
branch: str | None = None
enabled: bool | None = None
class AutomationResponse(BaseModel):
id: str
name: str
enabled: bool
trigger_type: str
config: dict
file_url: str | None = None
last_triggered_at: str | None = None
created_at: str
updated_at: str
class AutomationRunResponse(BaseModel):
id: str
automation_id: str
conversation_id: str | None = None
status: str
error_detail: str | None = None
started_at: str | None = None
completed_at: str | None = None
created_at: str
class PaginatedAutomationsResponse(BaseModel):
items: list[AutomationResponse]
total: int
next_page_id: str | None = None
class PaginatedRunsResponse(BaseModel):
items: list[AutomationRunResponse]
total: int
next_page_id: str | None = None

View File

@@ -0,0 +1,465 @@
"""FastAPI router for automation CRUD API (Phase 1: simple mode only)."""
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from services.automation_config import extract_config, validate_config
from services.automation_event_publisher import publish_automation_event
from services.automation_file_generator import generate_automation_file
from sqlalchemy import delete, func, select
from storage.automation import Automation, AutomationRun
from storage.database import a_session_maker
from openhands.core.logger import openhands_logger as logger
from openhands.server.shared import file_store
from openhands.server.user_auth import get_user_id
from .automation_models import (
AutomationResponse,
AutomationRunResponse,
CreateAutomationRequest,
PaginatedAutomationsResponse,
PaginatedRunsResponse,
UpdateAutomationRequest,
)
automation_router = APIRouter(
prefix='/api/v1/automations',
tags=['automations'],
)
FILE_STORE_PREFIX = 'automations'
def _file_store_key(automation_id: str) -> str:
return f'{FILE_STORE_PREFIX}/{automation_id}/automation.py'
def _paginate(rows: list, limit: int, id_attr: str = 'id') -> tuple[list, str | None]:
"""Return (items, next_page_id) from an overfetched result set."""
if len(rows) > limit:
return rows[:limit], getattr(rows[limit], id_attr)
return rows, None
def _automation_to_response(automation: Automation) -> AutomationResponse:
return AutomationResponse(
id=automation.id,
name=automation.name,
enabled=automation.enabled,
trigger_type=automation.trigger_type,
config=automation.config or {},
file_url=None,
last_triggered_at=(
automation.last_triggered_at.isoformat()
if automation.last_triggered_at
else None
),
created_at=automation.created_at.isoformat() if automation.created_at else '',
updated_at=automation.updated_at.isoformat() if automation.updated_at else '',
)
def _run_to_response(run: AutomationRun) -> AutomationRunResponse:
return AutomationRunResponse(
id=run.id,
automation_id=run.automation_id,
conversation_id=run.conversation_id,
status=run.status,
error_detail=run.error_detail,
started_at=run.started_at.isoformat() if run.started_at else None,
completed_at=run.completed_at.isoformat() if run.completed_at else None,
created_at=run.created_at.isoformat() if run.created_at else '',
)
def _generate_and_validate_file(
name: str,
schedule: str,
timezone: str,
prompt: str,
repository: str | None = None,
branch: str | None = None,
) -> tuple[str, dict]:
"""Generate automation file, extract config, validate, and store prompt in config.
Returns (file_content, config_dict).
Raises HTTPException on validation failure.
"""
file_content = generate_automation_file(
name=name,
schedule=schedule,
timezone=timezone,
prompt=prompt,
repository=repository,
branch=branch,
)
config = extract_config(file_content)
try:
validate_config(config)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=f'Invalid automation config: {e}',
)
# Store prompt in config so DB is the source of truth (not the file)
config['prompt'] = prompt
return file_content, config
@automation_router.post('', status_code=status.HTTP_201_CREATED)
async def create_automation(
request: CreateAutomationRequest,
user_id: str = Depends(get_user_id),
) -> AutomationResponse:
"""Create an automation from simple mode input (Phase 1).
Generates a .py file, uploads to object store, stores metadata in DB.
"""
file_content, config = _generate_and_validate_file(
name=request.name,
schedule=request.schedule,
timezone=request.timezone,
prompt=request.prompt,
repository=request.repository,
branch=request.branch,
)
automation_id = uuid.uuid4().hex
key = _file_store_key(automation_id)
try:
file_store.write(key, file_content)
except Exception:
logger.exception('Failed to upload automation file to object store')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to store automation file',
)
automation = Automation(
id=automation_id,
user_id=user_id,
name=request.name,
enabled=True,
config=config,
trigger_type='cron',
file_store_key=key,
)
async with a_session_maker() as session:
session.add(automation)
await session.commit()
await session.refresh(automation)
logger.info(
'Created automation',
extra={'automation_id': automation_id, 'user_id': user_id},
)
return _automation_to_response(automation)
@automation_router.get('/search')
async def search_automations(
user_id: str = Depends(get_user_id),
page_id: Annotated[
str | None,
Query(title='Cursor for pagination (automation ID)'),
] = None,
limit: Annotated[
int,
Query(title='Max results per page', gt=0, le=100),
] = 20,
) -> PaginatedAutomationsResponse:
"""List automations for the current user, paginated."""
async with a_session_maker() as session:
base_filter = select(Automation).where(Automation.user_id == user_id)
# Total count
count_q = select(func.count()).select_from(base_filter.subquery())
total = (await session.execute(count_q)).scalar() or 0
# Paginated query ordered by created_at desc
query = base_filter.order_by(Automation.created_at.desc())
if page_id:
cursor_row = (
await session.execute(
select(Automation.created_at).where(Automation.id == page_id)
)
).scalar()
if cursor_row is not None:
query = query.where(Automation.created_at < cursor_row)
query = query.limit(limit + 1)
result = await session.execute(query)
rows = list(result.scalars().all())
items, next_page_id = _paginate(rows, limit)
return PaginatedAutomationsResponse(
items=[_automation_to_response(a) for a in items],
total=total,
next_page_id=next_page_id,
)
@automation_router.get('/{automation_id}')
async def get_automation(
automation_id: str,
user_id: str = Depends(get_user_id),
) -> AutomationResponse:
"""Get a single automation by ID."""
async with a_session_maker() as session:
result = await session.execute(
select(Automation).where(
Automation.id == automation_id,
Automation.user_id == user_id,
)
)
automation = result.scalars().first()
if not automation:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail='Automation not found',
)
return _automation_to_response(automation)
@automation_router.patch('/{automation_id}')
async def update_automation(
automation_id: str,
request: UpdateAutomationRequest,
user_id: str = Depends(get_user_id),
) -> AutomationResponse:
"""Update an automation. Re-generates file if prompt/schedule/timezone/name changed."""
async with a_session_maker() as session:
result = await session.execute(
select(Automation).where(
Automation.id == automation_id,
Automation.user_id == user_id,
)
)
automation = result.scalars().first()
if not automation:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail='Automation not found',
)
updates = {
k: v
for k, v in request.model_dump(exclude_unset=True).items()
if v is not None
}
file_regen_fields = {'schedule', 'timezone', 'prompt', 'name'}
needs_regen = bool(updates.keys() & file_regen_fields)
if needs_regen:
current_config = automation.config or {}
current_triggers = current_config.get('triggers', {}).get('cron', {})
# Merge: use request values if provided, else fall back to current config
new_name = updates.get('name', automation.name)
new_schedule = updates.get(
'schedule', current_triggers.get('schedule', '')
)
new_timezone = updates.get(
'timezone', current_triggers.get('timezone', 'UTC')
)
prompt = updates.get('prompt', current_config.get('prompt', ''))
file_content, config = _generate_and_validate_file(
name=new_name,
schedule=new_schedule,
timezone=new_timezone,
prompt=prompt,
repository=updates.get('repository'),
branch=updates.get('branch'),
)
try:
file_store.write(automation.file_store_key, file_content)
except Exception:
logger.exception('Failed to upload updated automation file')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to store updated automation file',
)
automation.config = config
automation.name = new_name
if 'name' in updates and not needs_regen:
automation.name = updates['name']
if 'enabled' in updates:
automation.enabled = updates['enabled']
await session.commit()
await session.refresh(automation)
return _automation_to_response(automation)
@automation_router.delete('/{automation_id}', status_code=status.HTTP_204_NO_CONTENT)
async def delete_automation(
automation_id: str,
user_id: str = Depends(get_user_id),
) -> None:
"""Delete an automation and all its runs."""
async with a_session_maker() as session:
result = await session.execute(
select(Automation).where(
Automation.id == automation_id,
Automation.user_id == user_id,
)
)
automation = result.scalars().first()
if not automation:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail='Automation not found',
)
file_key = automation.file_store_key
# Delete runs first
await session.execute(
delete(AutomationRun).where(AutomationRun.automation_id == automation_id)
)
await session.delete(automation)
await session.commit()
# Best-effort cleanup of file store (DB is source of truth)
try:
file_store.delete(file_key)
except Exception:
logger.warning(
'Failed to delete automation file from object store',
extra={'automation_id': automation_id},
)
@automation_router.post('/{automation_id}/run', status_code=status.HTTP_202_ACCEPTED)
async def trigger_manual_run(
automation_id: str,
user_id: str = Depends(get_user_id),
) -> dict:
"""Manually trigger an automation run."""
async with a_session_maker() as session:
result = await session.execute(
select(Automation).where(
Automation.id == automation_id,
Automation.user_id == user_id,
)
)
automation = result.scalars().first()
if not automation:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail='Automation not found',
)
dedup_key = f'manual-{automation_id}-{uuid.uuid4().hex}'
await publish_automation_event(
session=session,
source_type='manual',
payload={'automation_id': automation_id},
dedup_key=dedup_key,
)
await session.commit()
return {'status': 'accepted', 'dedup_key': dedup_key}
@automation_router.get('/{automation_id}/runs')
async def list_automation_runs(
automation_id: str,
user_id: str = Depends(get_user_id),
page_id: Annotated[
str | None,
Query(title='Cursor for pagination (run ID)'),
] = None,
limit: Annotated[
int,
Query(title='Max results per page', gt=0, le=100),
] = 20,
) -> PaginatedRunsResponse:
"""List runs for an automation, paginated."""
# Verify ownership
async with a_session_maker() as session:
ownership = await session.execute(
select(Automation.id).where(
Automation.id == automation_id,
Automation.user_id == user_id,
)
)
if not ownership.scalar():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail='Automation not found',
)
base_filter = select(AutomationRun).where(
AutomationRun.automation_id == automation_id
)
count_q = select(func.count()).select_from(base_filter.subquery())
total = (await session.execute(count_q)).scalar() or 0
query = base_filter.order_by(AutomationRun.created_at.desc())
if page_id:
cursor_row = (
await session.execute(
select(AutomationRun.created_at).where(AutomationRun.id == page_id)
)
).scalar()
if cursor_row is not None:
query = query.where(AutomationRun.created_at < cursor_row)
query = query.limit(limit + 1)
result = await session.execute(query)
rows = list(result.scalars().all())
items, next_page_id = _paginate(rows, limit)
return PaginatedRunsResponse(
items=[_run_to_response(r) for r in items],
total=total,
next_page_id=next_page_id,
)
@automation_router.get('/{automation_id}/runs/{run_id}')
async def get_automation_run(
automation_id: str,
run_id: str,
user_id: str = Depends(get_user_id),
) -> AutomationRunResponse:
"""Get a single run detail."""
async with a_session_maker() as session:
# Verify ownership of the automation
ownership = await session.execute(
select(Automation.id).where(
Automation.id == automation_id,
Automation.user_id == user_id,
)
)
if not ownership.scalar():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail='Automation not found',
)
result = await session.execute(
select(AutomationRun).where(
AutomationRun.id == run_id,
AutomationRun.automation_id == automation_id,
)
)
run = result.scalars().first()
if not run:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail='Run not found',
)
return _run_to_response(run)

View File

@@ -0,0 +1,52 @@
"""Automation config extraction and validation.
NOTE: This is a stub for Task 2 (CRUD API) development.
Task 1 (Data Foundation) will provide the full implementation.
"""
import ast
from pydantic import BaseModel, Field
class CronTriggerModel(BaseModel):
schedule: str = Field(pattern=r'^(\S+\s+){4}\S+$')
timezone: str = 'UTC'
class TriggersModel(BaseModel):
cron: CronTriggerModel | None = None
def model_post_init(self, __context: object) -> None:
defined = [k for k in ('cron',) if getattr(self, k) is not None]
if len(defined) != 1:
raise ValueError(f'Exactly one trigger required, got: {defined or "none"}')
class AutomationConfigModel(BaseModel):
name: str = Field(min_length=1, max_length=200)
triggers: TriggersModel
description: str = ''
def extract_config(source: str) -> dict:
"""Extract __config__ dict from a Python automation file using AST."""
tree = ast.parse(source)
for node in ast.iter_child_nodes(tree):
if isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and target.id == '__config__':
return ast.literal_eval(node.value)
if isinstance(node, ast.AnnAssign):
if (
isinstance(node.target, ast.Name)
and node.target.id == '__config__'
and node.value is not None
):
return ast.literal_eval(node.value)
raise ValueError('No __config__ dict found in automation file')
def validate_config(config: dict) -> AutomationConfigModel:
"""Validate a __config__ dict. Returns parsed model or raises ValidationError."""
return AutomationConfigModel.model_validate(config)

View File

@@ -0,0 +1,26 @@
"""Automation event publisher.
NOTE: This is a stub for Task 2 (CRUD API) development.
Task 1 (Data Foundation) will provide the full implementation.
"""
from typing import Any
from storage.automation_event import AutomationEvent
async def publish_automation_event(
session: Any,
source_type: str,
payload: dict,
dedup_key: str,
) -> AutomationEvent:
"""Insert a new automation event into the automation_events table."""
event = AutomationEvent(
source_type=source_type,
payload=payload,
dedup_key=dedup_key,
status='NEW',
)
session.add(event)
return event

View File

@@ -1,555 +0,0 @@
"""Automation executor — processes events, claims and executes runs.
The executor is a long-running process with three phases:
1. Process inbox: match NEW events to automations, create PENDING runs
2. Claim and execute: claim PENDING runs, submit to V1 API, heartbeat
3. Stale recovery: recover RUNNING runs with expired heartbeats
"""
import asyncio
import logging
import os
import socket
from datetime import datetime, timedelta, timezone
from uuid import uuid4
from services.openhands_api_client import OpenHandsAPIClient
from sqlalchemy import or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from storage.automation import Automation, AutomationRun
from storage.automation_event import AutomationEvent
logger = logging.getLogger('saas.automation.executor')
# Environment-configurable settings
POLL_INTERVAL_SECONDS = float(os.getenv('POLL_INTERVAL_SECONDS', '30'))
HEARTBEAT_INTERVAL_SECONDS = float(os.getenv('HEARTBEAT_INTERVAL_SECONDS', '60'))
RUN_TIMEOUT_SECONDS = float(os.getenv('RUN_TIMEOUT_SECONDS', '7200'))
MAX_CONCURRENT_RUNS = int(os.getenv('MAX_CONCURRENT_RUNS', '5'))
STALE_THRESHOLD_MINUTES = 5
MAX_EVENTS_PER_BATCH = 50
MAX_RETRIES_DEFAULT = 3
# Terminal conversation statuses
TERMINAL_STATUSES = frozenset({'STOPPED', 'ERROR', 'COMPLETED', 'CANCELLED'})
# Shutdown flag — set by signal handlers
_shutdown_event: asyncio.Event | None = None
# Background task tracking for graceful shutdown
_pending_tasks: set[asyncio.Task] = set()
def utc_now() -> datetime:
return datetime.now(timezone.utc)
def get_shutdown_event() -> asyncio.Event:
global _shutdown_event
if _shutdown_event is None:
_shutdown_event = asyncio.Event()
return _shutdown_event
def should_continue() -> bool:
return not get_shutdown_event().is_set()
def request_shutdown() -> None:
get_shutdown_event().set()
# ---------------------------------------------------------------------------
# Phase 1: Process inbox (event matching)
# ---------------------------------------------------------------------------
async def find_matching_automations(
session: AsyncSession, event: AutomationEvent
) -> list[Automation]:
"""Find automations that match the given event.
Phase 1 supports cron and manual triggers only — both carry
``automation_id`` in the event payload.
"""
source_type = event.source_type
payload = event.payload
if payload is None:
logger.error('Event %s has None payload — possible data corruption', event.id)
return []
if source_type in ('cron', 'manual'):
automation_id = payload.get('automation_id')
if not automation_id:
logger.warning(
'Event %s (source=%s) missing automation_id in payload',
event.id,
source_type,
)
return []
result = await session.execute(
select(Automation).where(
Automation.id == automation_id,
Automation.enabled.is_(True),
)
)
automation = result.scalar_one_or_none()
return [automation] if automation else []
logger.debug('Unhandled event source_type=%s for event %s', source_type, event.id)
return []
async def process_new_events(session: AsyncSession) -> int:
"""Claim NEW events from inbox, match to automations, create runs.
Returns the number of events processed.
"""
result = await session.execute(
select(AutomationEvent)
.where(AutomationEvent.status == 'NEW')
.order_by(AutomationEvent.created_at)
.limit(MAX_EVENTS_PER_BATCH)
.with_for_update(skip_locked=True)
)
events = list(result.scalars())
processed = 0
for event in events:
try:
automations = await find_matching_automations(session, event)
if not automations:
event.status = 'NO_MATCH'
event.processed_at = utc_now()
else:
for automation in automations:
run = AutomationRun(
id=uuid4().hex,
automation_id=automation.id,
event_id=event.id,
status='PENDING',
event_payload=event.payload,
)
session.add(run)
event.status = 'PROCESSED'
event.processed_at = utc_now()
processed += 1
except Exception as e:
logger.exception('Error processing event %s', event.id)
event.status = 'ERROR'
event.error_detail = f'Failed during event matching: {type(e).__name__}: {e}'
event.processed_at = utc_now()
if processed:
await session.commit()
logger.info('Processed %d events', processed)
return processed
# ---------------------------------------------------------------------------
# Phase 2: Claim and execute runs
# ---------------------------------------------------------------------------
async def resolve_user_api_key(session: AsyncSession, user_id: str) -> str | None:
"""Look up a user's API key from the api_keys table.
Returns the first active key found, or None.
"""
from storage.api_key import ApiKey
result = await session.execute(
select(ApiKey.key).where(ApiKey.user_id == user_id).limit(1)
)
row = result.scalar_one_or_none()
return row
async def download_automation_file(file_store_key: str) -> bytes:
"""Download the automation .py file from object storage."""
try:
from openhands.server.shared import file_store
except ImportError as exc:
raise RuntimeError(
'file_store is not available — ensure the enterprise server '
'has been initialised before calling download_automation_file'
) from exc
content = file_store.read(file_store_key)
if isinstance(content, str):
return content.encode('utf-8')
return content
def is_terminal(conversation: dict) -> bool:
"""Check if a conversation has reached a terminal status."""
status = (conversation.get('status') or '').upper()
return status in TERMINAL_STATUSES
async def _prepare_run(
run: AutomationRun,
automation: Automation,
session_factory: object,
) -> tuple[str, bytes]:
"""Resolve the user's API key and download the automation file.
Returns:
(api_key, automation_file) tuple ready for submission.
Raises:
ValueError: If no API key is found.
RuntimeError: If file_store is unavailable.
"""
async with session_factory() as key_session:
api_key = await resolve_user_api_key(key_session, automation.user_id)
if not api_key:
raise ValueError(f'No API key found for user {automation.user_id}')
automation_file = await download_automation_file(automation.file_store_key)
return api_key, automation_file
async def _monitor_conversation(
run: AutomationRun,
conversation_id: str,
api_client: OpenHandsAPIClient,
api_key: str,
session_factory: object,
) -> bool:
"""Monitor a conversation until completion or timeout.
Returns True if completed successfully, False if shutdown requested.
Raises:
TimeoutError: If the run exceeds RUN_TIMEOUT_SECONDS.
"""
start_time = utc_now()
while should_continue():
elapsed = (utc_now() - start_time).total_seconds()
if elapsed > RUN_TIMEOUT_SECONDS:
raise TimeoutError(f'Run exceeded {RUN_TIMEOUT_SECONDS}s timeout')
await asyncio.sleep(HEARTBEAT_INTERVAL_SECONDS)
# Update heartbeat
async with session_factory() as session:
run_obj = await session.get(AutomationRun, run.id)
if run_obj:
run_obj.heartbeat_at = utc_now()
await session.commit()
# Check conversation status
conversation = (
await api_client.get_conversation(api_key, conversation_id) or {}
)
if is_terminal(conversation):
return True
return False # shutdown requested
async def _submit_and_monitor(
run: AutomationRun,
api_key: str,
automation_file: bytes,
automation: Automation,
api_client: OpenHandsAPIClient,
session_factory: object,
) -> None:
"""Submit the automation to the V1 API and monitor until completion.
Updates the run's conversation_id, sends heartbeats, and marks the
final status when the conversation reaches a terminal state.
"""
conversation = await api_client.start_conversation(
api_key=api_key,
automation_file=automation_file,
title=f'Automation: {automation.name}',
event_payload=run.event_payload,
)
conversation_id = conversation.get('app_conversation_id') or conversation.get(
'conversation_id'
)
# Persist conversation ID
async with session_factory() as update_session:
run_obj = await update_session.get(AutomationRun, run.id)
if run_obj:
run_obj.conversation_id = conversation_id
await update_session.commit()
# Monitor with heartbeats
completed = await _monitor_conversation(
run, conversation_id, api_client, api_key, session_factory
)
# Update final status
async with session_factory() as final_session:
run_obj = await final_session.get(AutomationRun, run.id)
if run_obj:
if not completed:
# Leave as RUNNING — stale recovery will handle it if needed.
# The conversation may still be running on the API side.
logger.info(
'Run %s left as RUNNING due to executor shutdown', run.id
)
else:
run_obj.status = 'COMPLETED'
run_obj.completed_at = utc_now()
logger.info('Run %s completed successfully', run.id)
await final_session.commit()
async def execute_run(
run: AutomationRun,
automation: Automation,
api_client: OpenHandsAPIClient,
session_factory: object,
) -> None:
"""Execute a single automation run end-to-end.
Orchestrates preparation (API key + file download) and submission/monitoring.
On failure, marks the run for retry or dead-letter.
"""
try:
api_key, automation_file = await _prepare_run(
run, automation, session_factory
)
await _submit_and_monitor(
run, api_key, automation_file, automation, api_client, session_factory
)
except Exception as e:
logger.exception('Run %s failed: %s', run.id, e)
await _mark_run_failed(run, str(e), session_factory)
async def _mark_run_failed(
run: AutomationRun, error: str, session_factory: object
) -> None:
"""Mark a run as FAILED or return to PENDING for retry."""
async with session_factory() as session:
run_obj = await session.get(AutomationRun, run.id)
if not run_obj:
return
run_obj.retry_count = (run_obj.retry_count or 0) + 1
run_obj.error_detail = error
if run_obj.retry_count >= (run_obj.max_retries or MAX_RETRIES_DEFAULT):
run_obj.status = 'DEAD_LETTER'
run_obj.completed_at = utc_now()
logger.error(
'Run %s moved to DEAD_LETTER after %d retries',
run.id,
run_obj.retry_count,
)
else:
run_obj.status = 'PENDING'
run_obj.claimed_by = None
backoff_seconds = 30 * (2 ** (run_obj.retry_count - 1))
run_obj.next_retry_at = utc_now() + timedelta(seconds=backoff_seconds)
logger.warning(
'Run %s returned to PENDING, retry %d/%d in %ds',
run.id,
run_obj.retry_count,
run_obj.max_retries or MAX_RETRIES_DEFAULT,
backoff_seconds,
)
await session.commit()
async def claim_and_execute_runs(
session: AsyncSession,
executor_id: str,
api_client: OpenHandsAPIClient,
session_factory: object,
) -> bool:
"""Claim a PENDING run and start executing it.
Returns True if a run was claimed, False otherwise.
"""
result = await session.execute(
select(AutomationRun)
.where(
AutomationRun.status == 'PENDING',
or_(
AutomationRun.next_retry_at.is_(None),
AutomationRun.next_retry_at <= utc_now(),
),
)
.order_by(AutomationRun.created_at)
.limit(1)
.with_for_update(skip_locked=True)
)
run = result.scalar_one_or_none()
if not run:
return False
# Claim the run
run.status = 'RUNNING'
run.claimed_by = executor_id
run.claimed_at = utc_now()
run.heartbeat_at = utc_now()
run.started_at = utc_now()
await session.commit()
# Load automation for the run
auto_result = await session.execute(
select(Automation).where(Automation.id == run.automation_id)
)
automation = auto_result.scalar_one_or_none()
if not automation:
logger.error('Automation %s not found for run %s', run.automation_id, run.id)
await _mark_run_failed(
run, f'Automation {run.automation_id} not found', session_factory
)
return True
# Execute in background (long-running) with task tracking
task = asyncio.create_task(
execute_run(run, automation, api_client, session_factory),
name=f'execute-run-{run.id}',
)
_pending_tasks.add(task)
task.add_done_callback(_pending_tasks.discard)
logger.info(
'Claimed run %s (automation=%s) by executor %s',
run.id,
run.automation_id,
executor_id,
)
return True
# ---------------------------------------------------------------------------
# Phase 3: Stale run recovery
# ---------------------------------------------------------------------------
async def recover_stale_runs(session: AsyncSession) -> int:
"""Mark RUNNING runs with expired heartbeats as PENDING for retry.
Returns the number of recovered runs.
"""
stale_threshold = utc_now() - timedelta(minutes=STALE_THRESHOLD_MINUTES)
timeout_threshold = utc_now() - timedelta(seconds=RUN_TIMEOUT_SECONDS)
# Recover stale runs (heartbeat expired)
result = await session.execute(
update(AutomationRun)
.where(
AutomationRun.status == 'RUNNING',
AutomationRun.heartbeat_at < stale_threshold,
AutomationRun.heartbeat_at >= timeout_threshold,
)
.values(
status='PENDING',
claimed_by=None,
retry_count=AutomationRun.retry_count + 1,
next_retry_at=utc_now() + timedelta(seconds=30),
)
.returning(AutomationRun.id)
)
recovered_rows = result.fetchall()
# Mark truly timed-out runs as DEAD_LETTER
timeout_result = await session.execute(
update(AutomationRun)
.where(
AutomationRun.status == 'RUNNING',
AutomationRun.heartbeat_at < timeout_threshold,
)
.values(
status='DEAD_LETTER',
error_detail='Run exceeded timeout',
completed_at=utc_now(),
)
.returning(AutomationRun.id)
)
timed_out_rows = timeout_result.fetchall()
await session.commit()
recovered_count = len(recovered_rows)
timed_out_count = len(timed_out_rows)
if recovered_count:
logger.warning('Recovered %d stale automation runs', recovered_count)
if timed_out_count:
logger.warning(
'Marked %d automation runs as DEAD_LETTER (timeout)', timed_out_count
)
return recovered_count + timed_out_count
# ---------------------------------------------------------------------------
# Main executor loop
# ---------------------------------------------------------------------------
async def executor_main(session_factory: object | None = None) -> None:
"""Main executor loop.
Args:
session_factory: Async context manager that yields AsyncSession instances.
If None, uses the default ``a_session_maker`` from database module.
"""
if session_factory is None:
from storage.database import a_session_maker
session_factory = a_session_maker
executor_id = f'executor-{socket.gethostname()}-{os.getpid()}'
api_url = os.getenv('OPENHANDS_API_URL', 'http://openhands-service:3000')
api_client = OpenHandsAPIClient(base_url=api_url)
logger.info(
'Automation executor %s starting (api_url=%s, poll=%ss, heartbeat=%ss)',
executor_id,
api_url,
POLL_INTERVAL_SECONDS,
HEARTBEAT_INTERVAL_SECONDS,
)
try:
while should_continue():
try:
async with session_factory() as session:
await process_new_events(session)
async with session_factory() as session:
await claim_and_execute_runs(
session, executor_id, api_client, session_factory
)
async with session_factory() as session:
await recover_stale_runs(session)
except Exception:
logger.exception('Error in executor main loop iteration')
# Wait for next poll interval (or early wakeup on shutdown)
try:
await asyncio.wait_for(
get_shutdown_event().wait(),
timeout=POLL_INTERVAL_SECONDS,
)
except asyncio.TimeoutError:
pass # Normal — poll interval elapsed
finally:
if _pending_tasks:
logger.info(
'Waiting for %d running tasks to complete...', len(_pending_tasks)
)
await asyncio.gather(*_pending_tasks, return_exceptions=True)
await api_client.close()
logger.info('Automation executor %s shut down', executor_id)

View File

@@ -0,0 +1,49 @@
"""Automation file generator for simple mode (Phase 1).
NOTE: This is a stub for Task 2 (CRUD API) development.
Task 1 (Data Foundation) will provide the full implementation.
"""
import json
PROMPT_TEMPLATE = '''\
"""{name} — auto-generated from form input."""
__config__ = {config_json}
import os
from openhands.sdk import LLM, Conversation
from openhands.tools.preset.default import get_default_agent
llm = LLM(
model=os.getenv("LLM_MODEL", "anthropic/claude-sonnet-4-5-20250929"),
api_key=os.getenv("LLM_API_KEY"),
base_url=os.getenv("LLM_BASE_URL"),
)
agent = get_default_agent(llm=llm, cli_mode=True)
conversation = Conversation(agent=agent, workspace=os.getcwd())
conversation.send_message({prompt!r})
conversation.run()
'''
def generate_automation_file(
name: str,
schedule: str,
timezone: str,
prompt: str,
repository: str | None = None,
branch: str | None = None,
) -> str:
"""Generate a Python automation file from form input."""
config: dict = {
'name': name,
'triggers': {'cron': {'schedule': schedule, 'timezone': timezone}},
}
return PROMPT_TEMPLATE.format(
name=name,
config_json=json.dumps(config, indent=4),
prompt=prompt,
)

View File

@@ -1,93 +0,0 @@
"""HTTP client for the main OpenHands V1 API (internal cluster calls).
Used by the automation executor to create and monitor conversations
in the main OpenHands server.
"""
import base64
import logging
import httpx
logger = logging.getLogger('saas.automation.api_client')
def _raise_with_body(resp: httpx.Response) -> None:
"""Call raise_for_status, enriching the error with the response body."""
try:
resp.raise_for_status()
except httpx.HTTPStatusError as e:
error_body = resp.text[:500] if resp.text else 'no response body'
raise httpx.HTTPStatusError(
f'{e.args[0]} — Response: {error_body}',
request=e.request,
response=e.response,
) from e
class OpenHandsAPIClient:
"""Async HTTP client for the OpenHands V1 API."""
def __init__(self, base_url: str = 'http://openhands-service:3000'):
self.base_url = base_url.rstrip('/')
self.client = httpx.AsyncClient(base_url=self.base_url, timeout=30.0)
async def start_conversation(
self,
api_key: str,
automation_file: bytes,
title: str,
event_payload: dict | None = None,
) -> dict:
"""Submit an SDK script for sandboxed execution via V1 API.
Args:
api_key: User's API key for authentication.
automation_file: Raw bytes of the .py automation script.
title: Display title for the conversation.
event_payload: Optional trigger event data (injected as env var).
Returns:
Parsed JSON response containing conversation details.
Raises:
httpx.HTTPStatusError: If the API returns a non-2xx status.
"""
resp = await self.client.post(
'/api/v1/app-conversations',
json={
'automation_file': base64.b64encode(automation_file).decode(),
'trigger': 'automation',
'title': title,
'event_payload': event_payload,
},
headers={'Authorization': f'Bearer {api_key}'},
)
_raise_with_body(resp)
return resp.json()
async def get_conversation(self, api_key: str, conversation_id: str) -> dict | None:
"""Get conversation status.
Args:
api_key: User's API key for authentication.
conversation_id: The conversation ID to look up.
Returns:
Conversation data dict, or None if not found.
Raises:
httpx.HTTPStatusError: If the API returns a non-2xx status.
"""
resp = await self.client.get(
'/api/v1/app-conversations',
params={'ids': [conversation_id]},
headers={'Authorization': f'Bearer {api_key}'},
)
_raise_with_body(resp)
conversations = resp.json()
return conversations[0] if conversations else None
async def close(self) -> None:
"""Close the underlying HTTP client."""
await self.client.aclose()

View File

@@ -1,77 +1,47 @@
"""SQLAlchemy models for automations and automation runs.
"""SQLAlchemy models for automations.
Stub for Task 1 (Data Foundation). These models will be replaced when Task 1
is merged into automations-phase1.
NOTE: This is a stub for Task 2 (CRUD API) development.
Task 1 (Data Foundation) will provide the full implementation.
"""
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
Text,
text,
)
from sqlalchemy.orm import relationship
from sqlalchemy.types import JSON
from sqlalchemy import JSON, Boolean, Column, DateTime, String
from sqlalchemy.sql import func
from storage.base import Base
class Automation(Base):
class Automation(Base): # type: ignore
__tablename__ = 'automations'
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False, index=True)
org_id = Column(String, nullable=True, index=True)
name = Column(String, nullable=False)
enabled = Column(Boolean, nullable=False, server_default=text('true'))
enabled = Column(Boolean, nullable=False, default=True)
config = Column(JSON, nullable=False)
trigger_type = Column(String, nullable=False)
file_store_key = Column(String, nullable=False)
last_triggered_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column(
DateTime(timezone=True),
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at = Column(
DateTime(timezone=True),
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
server_default=func.now(),
onupdate=func.now(),
)
runs = relationship('AutomationRun', back_populates='automation')
class AutomationRun(Base):
class AutomationRun(Base): # type: ignore
__tablename__ = 'automation_runs'
id = Column(String, primary_key=True)
automation_id = Column(
String, ForeignKey('automations.id', ondelete='CASCADE'), nullable=False
)
event_id = Column(Integer, ForeignKey('automation_events.id'), nullable=True)
automation_id = Column(String, nullable=False, index=True)
conversation_id = Column(String, nullable=True)
status = Column(String, nullable=False, server_default=text("'PENDING'"))
claimed_by = Column(String, nullable=True)
claimed_at = Column(DateTime(timezone=True), nullable=True)
heartbeat_at = Column(DateTime(timezone=True), nullable=True)
retry_count = Column(Integer, nullable=False, server_default=text('0'))
max_retries = Column(Integer, nullable=False, server_default=text('3'))
next_retry_at = Column(DateTime(timezone=True), nullable=True)
event_payload = Column(JSON, nullable=True)
error_detail = Column(Text, nullable=True)
status = Column(String, nullable=False, default='PENDING')
error_detail = Column(String, nullable=True)
started_at = Column(DateTime(timezone=True), nullable=True)
completed_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column(
DateTime(timezone=True),
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
DateTime(timezone=True), nullable=False, server_default=func.now()
)
automation = relationship('Automation', back_populates='runs')

View File

@@ -1,27 +1,25 @@
"""SQLAlchemy model for automation events (the inbox).
"""SQLAlchemy model for automation events.
Stub for Task 1 (Data Foundation). This model will be replaced when Task 1
is merged into automations-phase1.
NOTE: This is a stub for Task 2 (CRUD API) development.
Task 1 (Data Foundation) will provide the full implementation.
"""
from sqlalchemy import Column, DateTime, Integer, String, Text, text
from sqlalchemy.types import JSON
from sqlalchemy import JSON, BigInteger, Column, DateTime, String
from sqlalchemy.sql import func
from storage.base import Base
class AutomationEvent(Base):
class AutomationEvent(Base): # type: ignore
__tablename__ = 'automation_events'
id = Column(Integer, primary_key=True, autoincrement=True)
id = Column(BigInteger, primary_key=True, autoincrement=True)
source_type = Column(String, nullable=False)
payload = Column(JSON, nullable=False)
metadata_ = Column('metadata', JSON, nullable=True)
dedup_key = Column(String, nullable=False, unique=True)
status = Column(String, nullable=False, server_default=text("'NEW'"))
error_detail = Column(Text, nullable=True)
status = Column(String, nullable=False, default='NEW')
error_detail = Column(String, nullable=True)
created_at = Column(
DateTime(timezone=True),
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
DateTime(timezone=True), nullable=False, server_default=func.now()
)
processed_at = Column(DateTime(timezone=True), nullable=True)

View File

@@ -0,0 +1,653 @@
"""Unit tests for automation CRUD API routes."""
import uuid
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI, status
from fastapi.testclient import TestClient
from server.routes.automations import automation_router
from openhands.server.user_auth import get_user_id
TEST_USER_ID = str(uuid.uuid4())
OTHER_USER_ID = str(uuid.uuid4())
def _make_automation(
automation_id: str | None = None,
user_id: str = TEST_USER_ID,
name: str = 'Test Automation',
enabled: bool = True,
trigger_type: str = 'cron',
schedule: str = '0 9 * * 5',
timezone: str = 'UTC',
file_store_key: str | None = None,
):
auto_id = automation_id or uuid.uuid4().hex
mock = MagicMock()
mock.id = auto_id
mock.user_id = user_id
mock.name = name
mock.enabled = enabled
mock.trigger_type = trigger_type
mock.config = {
'name': name,
'triggers': {'cron': {'schedule': schedule, 'timezone': timezone}},
}
mock.file_store_key = file_store_key or f'automations/{auto_id}/automation.py'
mock.last_triggered_at = None
mock.created_at = datetime(2026, 1, 1, tzinfo=UTC)
mock.updated_at = datetime(2026, 1, 1, tzinfo=UTC)
return mock
def _make_run(
run_id: str | None = None,
automation_id: str = 'auto-1',
conversation_id: str | None = None,
run_status: str = 'PENDING',
):
rid = run_id or uuid.uuid4().hex
mock = MagicMock()
mock.id = rid
mock.automation_id = automation_id
mock.conversation_id = conversation_id
mock.status = run_status
mock.error_detail = None
mock.started_at = None
mock.completed_at = None
mock.created_at = datetime(2026, 1, 2, tzinfo=UTC)
return mock
# --- Helpers to mock async DB sessions ---
def _mock_session_with_results(results_by_call):
"""Create a mock async session that returns preconfigured results.
results_by_call: list of values; each session.execute() returns
the next value wrapped in a mock result.
"""
call_index = [0]
session = AsyncMock()
async def _execute(stmt):
idx = call_index[0]
call_index[0] += 1
val = results_by_call[idx] if idx < len(results_by_call) else None
result_mock = MagicMock()
if isinstance(val, list):
result_mock.scalars.return_value.all.return_value = val
result_mock.scalars.return_value.first.return_value = (
val[0] if val else None
)
result_mock.scalar.return_value = len(val)
elif val is None:
result_mock.scalars.return_value.first.return_value = None
result_mock.scalars.return_value.all.return_value = []
result_mock.scalar.return_value = None
else:
result_mock.scalars.return_value.first.return_value = val
result_mock.scalar.return_value = val
return result_mock
session.execute = AsyncMock(side_effect=_execute)
session.commit = AsyncMock()
session.refresh = AsyncMock()
session.delete = AsyncMock()
session.add = MagicMock()
return session
@asynccontextmanager
async def _session_ctx(session):
yield session
# --- Fixtures ---
@pytest.fixture
def mock_app():
"""Create a test FastAPI app with automation routes and mocked auth."""
app = FastAPI()
app.include_router(automation_router)
def mock_get_user_id():
return TEST_USER_ID
app.dependency_overrides[get_user_id] = mock_get_user_id
return app
@pytest.fixture
def client(mock_app):
return TestClient(mock_app)
# --- Test: POST /api/v1/automations ---
class TestCreateAutomation:
def test_create_success(self, client):
"""POST with valid input → 201 with AutomationResponse."""
mock_session = _mock_session_with_results([])
async def fake_refresh(obj):
obj.created_at = datetime(2026, 1, 1, tzinfo=UTC)
obj.updated_at = datetime(2026, 1, 1, tzinfo=UTC)
obj.last_triggered_at = None
mock_session.refresh = AsyncMock(side_effect=fake_refresh)
with (
patch(
'server.routes.automations.generate_automation_file',
return_value='__config__ = {"name": "Test", "triggers": {"cron": {"schedule": "0 9 * * 5"}}}',
),
patch(
'server.routes.automations.extract_config',
return_value={
'name': 'Test',
'triggers': {'cron': {'schedule': '0 9 * * 5'}},
},
),
patch('server.routes.automations.validate_config'),
patch('server.routes.automations.file_store') as mock_fs,
patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
),
):
response = client.post(
'/api/v1/automations',
json={
'name': 'Test',
'schedule': '0 9 * * 5',
'prompt': 'Summarize PRs',
},
)
assert response.status_code == status.HTTP_201_CREATED
data = response.json()
assert data['name'] == 'Test'
assert data['enabled'] is True
assert data['trigger_type'] == 'cron'
assert 'id' in data
mock_fs.write.assert_called_once()
def test_create_missing_name(self, client):
"""POST with missing name → 422."""
response = client.post(
'/api/v1/automations',
json={'schedule': '0 9 * * 5', 'prompt': 'Test'},
)
assert response.status_code == 422
def test_create_empty_name(self, client):
"""POST with empty name → 422."""
response = client.post(
'/api/v1/automations',
json={'name': '', 'schedule': '0 9 * * 5', 'prompt': 'Test'},
)
assert response.status_code == 422
def test_create_missing_prompt(self, client):
"""POST with missing prompt → 422."""
response = client.post(
'/api/v1/automations',
json={'name': 'Test', 'schedule': '0 9 * * 5'},
)
assert response.status_code == 422
def test_create_invalid_config_rejected(self, client):
"""POST where validate_config raises → 422."""
with (
patch(
'server.routes.automations.generate_automation_file',
return_value='__config__ = {}',
),
patch(
'server.routes.automations.extract_config',
return_value={},
),
patch(
'server.routes.automations.validate_config',
side_effect=ValueError('Invalid cron expression'),
),
):
response = client.post(
'/api/v1/automations',
json={
'name': 'Bad Cron',
'schedule': 'not-a-cron',
'prompt': 'Test',
},
)
assert response.status_code == 422
assert 'Invalid automation config' in response.json()['detail']
# --- Test: GET /api/v1/automations/search ---
class TestSearchAutomations:
def test_list_returns_user_automations(self, client):
"""GET /search → returns only current user's automations."""
a1 = _make_automation(name='Auto 1')
a2 = _make_automation(name='Auto 2')
# Session calls: count query → 2, paginated query → [a1, a2]
mock_session = _mock_session_with_results([2, [a1, a2]])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = client.get('/api/v1/automations/search')
assert response.status_code == 200
data = response.json()
assert data['total'] == 2
assert len(data['items']) == 2
assert data['items'][0]['name'] == 'Auto 1'
def test_list_empty(self, client):
"""GET /search when no automations → empty list."""
mock_session = _mock_session_with_results([0, []])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = client.get('/api/v1/automations/search')
assert response.status_code == 200
data = response.json()
assert data['total'] == 0
assert data['items'] == []
assert data['next_page_id'] is None
# --- Test: GET /api/v1/automations/{id} ---
class TestGetAutomation:
def test_get_existing(self, client):
"""GET existing automation → 200."""
auto = _make_automation(automation_id='auto-123')
mock_session = _mock_session_with_results([auto])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = client.get('/api/v1/automations/auto-123')
assert response.status_code == 200
assert response.json()['id'] == 'auto-123'
def test_get_nonexistent(self, client):
"""GET non-existent automation → 404."""
mock_session = _mock_session_with_results([None])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = client.get('/api/v1/automations/does-not-exist')
assert response.status_code == status.HTTP_404_NOT_FOUND
# --- Test: PATCH /api/v1/automations/{id} ---
class TestUpdateAutomation:
def test_update_name_and_enabled(self, client):
"""PATCH with name + enabled → updates fields, returns 200."""
auto = _make_automation(automation_id='auto-123')
mock_session = _mock_session_with_results([auto])
async def fake_refresh(obj):
obj.name = 'Updated Name'
obj.enabled = False
obj.id = 'auto-123'
obj.trigger_type = 'cron'
obj.config = auto.config
obj.last_triggered_at = None
obj.created_at = datetime(2026, 1, 1, tzinfo=UTC)
obj.updated_at = datetime(2026, 1, 2, tzinfo=UTC)
mock_session.refresh = AsyncMock(side_effect=fake_refresh)
with (
patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
),
patch(
'server.routes.automations.generate_automation_file',
return_value='__config__ = {}',
),
patch(
'server.routes.automations.extract_config',
return_value=auto.config,
),
patch('server.routes.automations.validate_config'),
patch('server.routes.automations.file_store') as mock_fs,
):
mock_fs.read.return_value = (
'conversation.send_message("old prompt")\nconversation.run()'
)
response = client.patch(
'/api/v1/automations/auto-123',
json={'name': 'Updated Name', 'enabled': False},
)
assert response.status_code == 200
data = response.json()
assert data['name'] == 'Updated Name'
assert data['enabled'] is False
def test_update_nonexistent(self, client):
"""PATCH non-existent → 404."""
mock_session = _mock_session_with_results([None])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = client.patch(
'/api/v1/automations/nope',
json={'name': 'X'},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_update_prompt_regenerates_file(self, client):
"""PATCH with new prompt → re-generates file and uploads."""
auto = _make_automation(automation_id='auto-123')
mock_session = _mock_session_with_results([auto])
async def fake_refresh(obj):
obj.id = 'auto-123'
obj.name = auto.name
obj.enabled = auto.enabled
obj.trigger_type = 'cron'
obj.config = auto.config
obj.last_triggered_at = None
obj.created_at = datetime(2026, 1, 1, tzinfo=UTC)
obj.updated_at = datetime(2026, 1, 2, tzinfo=UTC)
mock_session.refresh = AsyncMock(side_effect=fake_refresh)
with (
patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
),
patch(
'server.routes.automations.generate_automation_file',
return_value='__config__ = {}',
) as mock_gen,
patch(
'server.routes.automations.extract_config',
return_value=auto.config,
),
patch('server.routes.automations.validate_config'),
patch('server.routes.automations.file_store') as mock_fs,
):
response = client.patch(
'/api/v1/automations/auto-123',
json={'prompt': 'New prompt text'},
)
assert response.status_code == 200
mock_gen.assert_called_once()
mock_fs.write.assert_called_once()
# --- Test: DELETE /api/v1/automations/{id} ---
class TestDeleteAutomation:
def test_delete_existing(self, client):
"""DELETE existing → 204."""
auto = _make_automation(automation_id='auto-123')
# First execute: select automation, second: delete runs
mock_session = _mock_session_with_results([auto, None])
with (
patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
),
patch('server.routes.automations.file_store'),
):
response = client.delete('/api/v1/automations/auto-123')
assert response.status_code == status.HTTP_204_NO_CONTENT
def test_delete_nonexistent(self, client):
"""DELETE non-existent → 404."""
mock_session = _mock_session_with_results([None])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = client.delete('/api/v1/automations/nope')
assert response.status_code == status.HTTP_404_NOT_FOUND
# --- Test: POST /api/v1/automations/{id}/run ---
class TestManualTrigger:
def test_manual_trigger_success(self, client):
"""POST .../run on existing automation → 202."""
auto = _make_automation(automation_id='auto-123')
mock_session = _mock_session_with_results([auto])
with (
patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
),
patch(
'server.routes.automations.publish_automation_event',
new_callable=AsyncMock,
) as mock_pub,
):
response = client.post('/api/v1/automations/auto-123/run')
assert response.status_code == status.HTTP_202_ACCEPTED
data = response.json()
assert data['status'] == 'accepted'
assert 'dedup_key' in data
assert data['dedup_key'].startswith('manual-auto-123-')
mock_pub.assert_called_once()
def test_manual_trigger_nonexistent(self, client):
"""POST .../run on non-existent → 404."""
mock_session = _mock_session_with_results([None])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = client.post('/api/v1/automations/nope/run')
assert response.status_code == status.HTTP_404_NOT_FOUND
# --- Test: GET /api/v1/automations/{id}/runs ---
class TestListRuns:
def test_list_runs_success(self, client):
"""GET .../runs → paginated list."""
r1 = _make_run(run_id='run-1', automation_id='auto-123')
r2 = _make_run(run_id='run-2', automation_id='auto-123')
# Calls: ownership check → 'auto-123', count → 2, paginated → [r1, r2]
mock_session = _mock_session_with_results(['auto-123', 2, [r1, r2]])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = client.get('/api/v1/automations/auto-123/runs')
assert response.status_code == 200
data = response.json()
assert data['total'] == 2
assert len(data['items']) == 2
def test_list_runs_automation_not_found(self, client):
"""GET .../runs for non-existent automation → 404."""
mock_session = _mock_session_with_results([None])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = client.get('/api/v1/automations/nope/runs')
assert response.status_code == status.HTTP_404_NOT_FOUND
# --- Test: GET /api/v1/automations/{id}/runs/{run_id} ---
class TestGetRun:
def test_get_run_success(self, client):
"""GET single run → 200."""
run = _make_run(run_id='run-1', automation_id='auto-123')
# Calls: ownership check → 'auto-123', select run → run
mock_session = _mock_session_with_results(['auto-123', run])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = client.get('/api/v1/automations/auto-123/runs/run-1')
assert response.status_code == 200
assert response.json()['id'] == 'run-1'
def test_get_run_not_found(self, client):
"""GET non-existent run → 404."""
mock_session = _mock_session_with_results(['auto-123', None])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = client.get('/api/v1/automations/auto-123/runs/nope')
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_get_run_automation_not_found(self, client):
"""GET run for non-existent automation → 404."""
mock_session = _mock_session_with_results([None])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = client.get('/api/v1/automations/nope/runs/run-1')
assert response.status_code == status.HTTP_404_NOT_FOUND
# --- Test: User Isolation (security) ---
class TestUserIsolation:
"""Verify that user A cannot access, update, or delete user B's automations.
The routes filter by user_id from the auth dependency, so automations owned by
another user should never be returned (the DB query uses WHERE user_id = <caller>).
We simulate this by having the mock session return None for cross-user lookups.
"""
@pytest.fixture
def other_user_app(self):
"""App configured to authenticate as OTHER_USER_ID."""
app = FastAPI()
app.include_router(automation_router)
def mock_get_other_user_id():
return OTHER_USER_ID
app.dependency_overrides[get_user_id] = mock_get_other_user_id
return app
@pytest.fixture
def other_client(self, other_user_app):
return TestClient(other_user_app)
def test_cannot_get_other_users_automation(self, other_client):
"""User B cannot GET user A's automation → 404."""
# The query filters by user_id=OTHER_USER_ID, so it won't find user A's row
mock_session = _mock_session_with_results([None])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = other_client.get('/api/v1/automations/auto-owned-by-a')
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_cannot_update_other_users_automation(self, other_client):
"""User B cannot PATCH user A's automation → 404."""
mock_session = _mock_session_with_results([None])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = other_client.patch(
'/api/v1/automations/auto-owned-by-a',
json={'name': 'Hijacked'},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_cannot_delete_other_users_automation(self, other_client):
"""User B cannot DELETE user A's automation → 404."""
mock_session = _mock_session_with_results([None])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = other_client.delete('/api/v1/automations/auto-owned-by-a')
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_search_returns_empty_for_other_user(self, other_client):
"""User B's search returns empty even if user A has automations."""
# count=0, rows=[]
mock_session = _mock_session_with_results([0, []])
with patch(
'server.routes.automations.a_session_maker',
return_value=_session_ctx(mock_session),
):
response = other_client.get('/api/v1/automations/search')
assert response.status_code == 200
data = response.json()
assert data['total'] == 0
assert data['items'] == []

View File

@@ -0,0 +1,264 @@
"""Integration tests for automation CRUD API using a real in-memory SQLite database.
These tests exercise actual SQL queries (list, get, create+verify, pagination, delete)
rather than mocking the database layer.
"""
import uuid
from contextlib import asynccontextmanager
from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from server.routes.automations import automation_router
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from storage.automation import Automation, AutomationRun
from openhands.app_server.utils.sql_utils import Base
from openhands.server.user_auth import get_user_id
TEST_USER_ID = 'integration-test-user'
OTHER_USER_ID = 'other-user'
@pytest.fixture
async def db_engine():
engine = create_async_engine('sqlite+aiosqlite://', echo=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await engine.dispose()
@pytest.fixture
def session_maker(db_engine):
return async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
@pytest.fixture
def app(session_maker):
"""FastAPI app wired to a real SQLite database."""
app = FastAPI()
app.include_router(automation_router)
app.dependency_overrides[get_user_id] = lambda: TEST_USER_ID
@asynccontextmanager
async def _session_ctx():
async with session_maker() as session:
yield session
with patch('server.routes.automations.a_session_maker', _session_ctx):
yield app
@pytest.fixture
async def client(app):
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url='http://test') as c:
yield c
def _make_automation_obj(
user_id: str = TEST_USER_ID,
name: str = 'Test Auto',
created_at: datetime | None = None,
**kwargs,
) -> Automation:
return Automation(
id=kwargs.get('automation_id', uuid.uuid4().hex),
user_id=user_id,
name=name,
enabled=kwargs.get('enabled', True),
config=kwargs.get(
'config',
{
'name': name,
'triggers': {'cron': {'schedule': '0 9 * * 5', 'timezone': 'UTC'}},
'prompt': 'Do something',
},
),
trigger_type='cron',
file_store_key=kwargs.get('file_store_key', f'automations/{uuid.uuid4().hex}/automation.py'),
created_at=created_at or datetime.now(UTC),
updated_at=created_at or datetime.now(UTC),
)
# ---------- Test: list (search) returns correct results ----------
@pytest.mark.asyncio
async def test_search_returns_user_automations(client, session_maker):
"""GET /search returns only automations owned by the requesting user."""
async with session_maker() as session:
a1 = _make_automation_obj(name='Auto A', created_at=datetime(2026, 1, 1, tzinfo=UTC))
a2 = _make_automation_obj(name='Auto B', created_at=datetime(2026, 1, 2, tzinfo=UTC))
a_other = _make_automation_obj(user_id=OTHER_USER_ID, name='Other User Auto')
session.add_all([a1, a2, a_other])
await session.commit()
response = await client.get('/api/v1/automations/search')
assert response.status_code == 200
data = response.json()
assert data['total'] == 2
assert len(data['items']) == 2
names = {item['name'] for item in data['items']}
assert names == {'Auto A', 'Auto B'}
# ---------- Test: get returns the right object ----------
@pytest.mark.asyncio
async def test_get_returns_correct_automation(client, session_maker):
"""GET /{id} returns the correct automation by ID."""
auto_id = uuid.uuid4().hex
async with session_maker() as session:
auto = _make_automation_obj(automation_id=auto_id, name='Specific Auto')
session.add(auto)
await session.commit()
response = await client.get(f'/api/v1/automations/{auto_id}')
assert response.status_code == 200
data = response.json()
assert data['id'] == auto_id
assert data['name'] == 'Specific Auto'
@pytest.mark.asyncio
async def test_get_nonexistent_returns_404(client):
"""GET /{id} for non-existent automation returns 404."""
response = await client.get('/api/v1/automations/does-not-exist')
assert response.status_code == 404
# ---------- Test: create + verify in DB ----------
@pytest.mark.asyncio
async def test_create_stores_in_db(client, session_maker):
"""POST creates an automation and it's readable from the database."""
mock_file_store = MagicMock()
config = {
'name': 'New Auto',
'triggers': {'cron': {'schedule': '0 9 * * 5', 'timezone': 'UTC'}},
}
with (
patch(
'server.routes.automations.generate_automation_file',
return_value='__config__ = {}',
),
patch('server.routes.automations.extract_config', return_value=config),
patch('server.routes.automations.validate_config'),
patch('server.routes.automations.file_store', mock_file_store),
):
response = await client.post(
'/api/v1/automations',
json={
'name': 'New Auto',
'schedule': '0 9 * * 5',
'prompt': 'Summarize PRs',
},
)
assert response.status_code == 201
data = response.json()
created_id = data['id']
# Verify it's in the DB via the GET endpoint
get_response = await client.get(f'/api/v1/automations/{created_id}')
assert get_response.status_code == 200
assert get_response.json()['name'] == 'New Auto'
# Verify prompt is stored in config
assert get_response.json()['config'].get('prompt') == 'Summarize PRs'
# ---------- Test: delete actually deletes ----------
@pytest.mark.asyncio
async def test_delete_removes_from_db(client, session_maker):
"""DELETE removes the automation from the database."""
auto_id = uuid.uuid4().hex
async with session_maker() as session:
auto = _make_automation_obj(automation_id=auto_id, name='To Delete')
session.add(auto)
await session.commit()
mock_file_store = MagicMock()
with patch('server.routes.automations.file_store', mock_file_store):
response = await client.delete(f'/api/v1/automations/{auto_id}')
assert response.status_code == 204
# Verify it's gone
get_response = await client.get(f'/api/v1/automations/{auto_id}')
assert get_response.status_code == 404
# ---------- Test: pagination actually works ----------
@pytest.mark.asyncio
async def test_pagination_returns_correct_pages(client, session_maker):
"""Pagination with limit returns correct page sizes and next_page_id."""
base_time = datetime(2026, 1, 1, tzinfo=UTC)
async with session_maker() as session:
for i in range(5):
auto = _make_automation_obj(
name=f'Auto {i}',
created_at=base_time + timedelta(hours=i),
)
session.add(auto)
await session.commit()
# First page with limit=2
response = await client.get('/api/v1/automations/search?limit=2')
assert response.status_code == 200
data = response.json()
assert data['total'] == 5
assert len(data['items']) == 2
assert data['next_page_id'] is not None
# Second page using cursor — should return remaining items before cursor
next_id = data['next_page_id']
response2 = await client.get(f'/api/v1/automations/search?limit=2&page_id={next_id}')
assert response2.status_code == 200
data2 = response2.json()
assert len(data2['items']) == 2
# Collect all items from both pages and verify no duplicates
all_ids = [item['id'] for item in data['items']] + [
item['id'] for item in data2['items']
]
assert len(all_ids) == len(set(all_ids)), 'Pages must not contain duplicate items'
# ---------- Test: user isolation at DB level ----------
@pytest.mark.asyncio
async def test_user_isolation(client, session_maker):
"""User A cannot see or access User B's automations via actual DB queries."""
auto_id = uuid.uuid4().hex
async with session_maker() as session:
other_auto = _make_automation_obj(
automation_id=auto_id,
user_id=OTHER_USER_ID,
name='Other User Auto',
)
session.add(other_auto)
await session.commit()
# Should not be found by TEST_USER_ID
response = await client.get(f'/api/v1/automations/{auto_id}')
assert response.status_code == 404
# Should not appear in search
search_response = await client.get('/api/v1/automations/search')
assert search_response.status_code == 200
assert search_response.json()['total'] == 0

View File

@@ -1,68 +0,0 @@
"""Shared fixtures for services tests.
Note: We pre-load ``storage`` as a namespace package to avoid the heavy
``storage/__init__.py`` that imports the entire enterprise model graph.
This must happen *before* any ``from storage.…`` import.
"""
import contextlib
import sys
import types
# Prevent storage/__init__.py from loading the full model graph.
# We only need the lightweight automation models for these tests.
if 'storage' not in sys.modules:
import pathlib
_storage_dir = str(pathlib.Path(__file__).resolve().parents[3] / 'storage')
_mod = types.ModuleType('storage')
_mod.__path__ = [_storage_dir]
sys.modules['storage'] = _mod
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.automation import Automation, AutomationRun # noqa: F401
from storage.automation_event import AutomationEvent # noqa: F401
from storage.base import Base
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session_factory(async_engine):
"""Create an async session factory that yields context-managed sessions."""
factory = async_sessionmaker(
bind=async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
@contextlib.asynccontextmanager
async def _session_ctx():
async with factory() as session:
yield session
return _session_ctx
@pytest.fixture
async def async_session(async_session_factory):
"""Create a single async session for testing."""
async with async_session_factory() as session:
yield session

View File

@@ -1,624 +0,0 @@
"""Tests for the automation executor.
Uses real SQLite database operations for event processing, run claiming,
and stale run recovery. HTTP calls to the V1 API are mocked.
"""
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from services.automation_executor import (
_mark_run_failed,
claim_and_execute_runs,
find_matching_automations,
is_terminal,
process_new_events,
recover_stale_runs,
utc_now,
)
from sqlalchemy import select
from storage.automation import Automation, AutomationRun
from storage.automation_event import AutomationEvent
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_automation(
automation_id: str = 'auto-1',
user_id: str = 'user-1',
enabled: bool = True,
trigger_type: str = 'cron',
name: str = 'Test Automation',
) -> Automation:
return Automation(
id=automation_id,
user_id=user_id,
org_id='org-1',
name=name,
enabled=enabled,
config={'triggers': {'cron': {'schedule': '0 9 * * 5'}}},
trigger_type=trigger_type,
file_store_key=f'automations/{automation_id}/script.py',
)
def make_event(
source_type: str = 'cron',
payload: dict | None = None,
status: str = 'NEW',
dedup_key: str | None = None,
) -> AutomationEvent:
return AutomationEvent(
source_type=source_type,
payload=payload or {'automation_id': 'auto-1'},
dedup_key=dedup_key or f'dedup-{uuid4().hex[:8]}',
status=status,
created_at=utc_now(),
)
def make_run(
run_id: str | None = None,
automation_id: str = 'auto-1',
status: str = 'PENDING',
claimed_by: str | None = None,
heartbeat_at: datetime | None = None,
retry_count: int = 0,
max_retries: int = 3,
next_retry_at: datetime | None = None,
) -> AutomationRun:
return AutomationRun(
id=run_id or uuid4().hex,
automation_id=automation_id,
status=status,
claimed_by=claimed_by,
heartbeat_at=heartbeat_at,
retry_count=retry_count,
max_retries=max_retries,
next_retry_at=next_retry_at,
event_payload={'automation_id': automation_id},
created_at=utc_now(),
)
# ---------------------------------------------------------------------------
# find_matching_automations
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_find_matching_automations_cron_event(async_session):
"""Cron events match by automation_id in payload."""
automation = make_automation()
async_session.add(automation)
await async_session.commit()
event = make_event(source_type='cron', payload={'automation_id': 'auto-1'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 1
assert result[0].id == 'auto-1'
@pytest.mark.asyncio
async def test_find_matching_automations_manual_event(async_session):
"""Manual events also match by automation_id in payload."""
automation = make_automation()
async_session.add(automation)
await async_session.commit()
event = make_event(source_type='manual', payload={'automation_id': 'auto-1'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 1
assert result[0].id == 'auto-1'
@pytest.mark.asyncio
async def test_find_matching_automations_disabled_automation(async_session):
"""Disabled automations are not matched."""
automation = make_automation(enabled=False)
async_session.add(automation)
await async_session.commit()
event = make_event(payload={'automation_id': 'auto-1'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 0
@pytest.mark.asyncio
async def test_find_matching_automations_missing_automation_id(async_session):
"""Events without automation_id in payload return empty list."""
event = make_event(payload={'something_else': 'value'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 0
@pytest.mark.asyncio
async def test_find_matching_automations_nonexistent_automation(async_session):
"""Events referencing a non-existent automation return empty list."""
event = make_event(payload={'automation_id': 'nonexistent'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 0
@pytest.mark.asyncio
async def test_find_matching_automations_unknown_source_type(async_session):
"""Unknown source types return empty list."""
event = make_event(source_type='unknown', payload={'automation_id': 'auto-1'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 0
# ---------------------------------------------------------------------------
# process_new_events
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_process_new_events_creates_runs(async_session):
"""Processing NEW events creates PENDING runs and marks events PROCESSED."""
automation = make_automation()
event = make_event(payload={'automation_id': 'auto-1'})
async_session.add_all([automation, event])
await async_session.commit()
count = await process_new_events(async_session)
assert count == 1
# Event should be PROCESSED
await async_session.refresh(event)
assert event.status == 'PROCESSED'
assert event.processed_at is not None
# A run should have been created
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 1
assert runs[0].automation_id == 'auto-1'
assert runs[0].status == 'PENDING'
assert runs[0].event_payload == {'automation_id': 'auto-1'}
@pytest.mark.asyncio
async def test_process_new_events_no_match(async_session):
"""Events with no matching automation are marked NO_MATCH."""
event = make_event(payload={'automation_id': 'nonexistent'})
async_session.add(event)
await async_session.commit()
count = await process_new_events(async_session)
assert count == 1
await async_session.refresh(event)
assert event.status == 'NO_MATCH'
assert event.processed_at is not None
# No runs created
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 0
@pytest.mark.asyncio
async def test_process_new_events_skips_processed(async_session):
"""Already processed events are not re-processed."""
event = make_event(status='PROCESSED')
async_session.add(event)
await async_session.commit()
count = await process_new_events(async_session)
assert count == 0
@pytest.mark.asyncio
async def test_process_new_events_multiple_events(async_session):
"""Multiple NEW events are processed in one batch."""
auto1 = make_automation(automation_id='auto-1')
auto2 = make_automation(automation_id='auto-2', name='Auto 2')
event1 = make_event(payload={'automation_id': 'auto-1'}, dedup_key='dedup-1')
event2 = make_event(payload={'automation_id': 'auto-2'}, dedup_key='dedup-2')
event3 = make_event(payload={'automation_id': 'nonexistent'}, dedup_key='dedup-3')
async_session.add_all([auto1, auto2, event1, event2, event3])
await async_session.commit()
count = await process_new_events(async_session)
assert count == 3
# Two runs created (for auto-1 and auto-2), none for nonexistent
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 2
await async_session.refresh(event1)
await async_session.refresh(event2)
await async_session.refresh(event3)
assert event1.status == 'PROCESSED'
assert event2.status == 'PROCESSED'
assert event3.status == 'NO_MATCH'
# ---------------------------------------------------------------------------
# claim_and_execute_runs
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_claim_and_execute_runs_claims_pending(
async_session, async_session_factory
):
"""Claims a PENDING run and transitions to RUNNING."""
automation = make_automation()
run = make_run(run_id='run-1')
async_session.add_all([automation, run])
await async_session.commit()
api_client = AsyncMock()
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is True
await async_session.refresh(run)
assert run.status == 'RUNNING'
assert run.claimed_by == 'executor-test-1'
assert run.claimed_at is not None
assert run.heartbeat_at is not None
assert run.started_at is not None
@pytest.mark.asyncio
async def test_claim_and_execute_runs_no_pending(async_session, async_session_factory):
"""Returns False when no PENDING runs exist."""
api_client = AsyncMock()
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is False
@pytest.mark.asyncio
async def test_claim_and_execute_runs_respects_next_retry_at(
async_session, async_session_factory
):
"""Runs with future next_retry_at are not claimed."""
automation = make_automation()
run = make_run(
run_id='run-retry',
next_retry_at=utc_now() + timedelta(hours=1),
)
async_session.add_all([automation, run])
await async_session.commit()
api_client = AsyncMock()
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is False
@pytest.mark.asyncio
async def test_claim_and_execute_runs_past_retry_at(
async_session, async_session_factory
):
"""Runs with past next_retry_at are claimable."""
automation = make_automation()
run = make_run(
run_id='run-retry-past',
next_retry_at=utc_now() - timedelta(minutes=5),
)
async_session.add_all([automation, run])
await async_session.commit()
api_client = AsyncMock()
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is True
@pytest.mark.asyncio
async def test_claim_skips_running_runs(async_session, async_session_factory):
"""RUNNING runs are not claimed."""
automation = make_automation()
run = make_run(run_id='run-running', status='RUNNING', claimed_by='other-executor')
async_session.add_all([automation, run])
await async_session.commit()
api_client = AsyncMock()
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is False
# ---------------------------------------------------------------------------
# recover_stale_runs
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_recover_stale_runs_recovers_stale(async_session):
"""RUNNING runs with expired heartbeats are recovered to PENDING."""
automation = make_automation()
stale_run = make_run(
run_id='stale-1',
status='RUNNING',
claimed_by='crashed-executor',
heartbeat_at=utc_now() - timedelta(minutes=10),
retry_count=0,
)
async_session.add_all([automation, stale_run])
await async_session.commit()
count = await recover_stale_runs(async_session)
assert count >= 1
await async_session.refresh(stale_run)
assert stale_run.status == 'PENDING'
assert stale_run.claimed_by is None
assert stale_run.retry_count == 1
assert stale_run.next_retry_at is not None
@pytest.mark.asyncio
async def test_recover_stale_runs_ignores_fresh(async_session):
"""RUNNING runs with recent heartbeats are not recovered."""
automation = make_automation()
fresh_run = make_run(
run_id='fresh-1',
status='RUNNING',
claimed_by='active-executor',
heartbeat_at=utc_now() - timedelta(seconds=30),
)
async_session.add_all([automation, fresh_run])
await async_session.commit()
count = await recover_stale_runs(async_session)
assert count == 0
await async_session.refresh(fresh_run)
assert fresh_run.status == 'RUNNING'
assert fresh_run.claimed_by == 'active-executor'
@pytest.mark.asyncio
async def test_recover_stale_runs_ignores_pending(async_session):
"""PENDING runs are not affected by recovery."""
automation = make_automation()
pending_run = make_run(run_id='pending-1', status='PENDING')
async_session.add_all([automation, pending_run])
await async_session.commit()
count = await recover_stale_runs(async_session)
assert count == 0
await async_session.refresh(pending_run)
assert pending_run.status == 'PENDING'
@pytest.mark.asyncio
async def test_recover_stale_runs_increments_retry_count(async_session):
"""Recovery increments the retry_count."""
automation = make_automation()
stale_run = make_run(
run_id='stale-retry',
status='RUNNING',
claimed_by='old-executor',
heartbeat_at=utc_now() - timedelta(minutes=10),
retry_count=2,
)
async_session.add_all([automation, stale_run])
await async_session.commit()
await recover_stale_runs(async_session)
await async_session.refresh(stale_run)
assert stale_run.retry_count == 3
# ---------------------------------------------------------------------------
# _mark_run_failed (error handling)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mark_run_failed_retries(async_session_factory):
"""Failed runs with retries left return to PENDING."""
async with async_session_factory() as session:
automation = make_automation()
run = make_run(run_id='fail-retry', retry_count=0, max_retries=3)
session.add_all([automation, run])
await session.commit()
async with async_session_factory() as session:
run_obj = await session.get(AutomationRun, 'fail-retry')
await _mark_run_failed(run_obj, 'API error', async_session_factory)
async with async_session_factory() as session:
run_obj = await session.get(AutomationRun, 'fail-retry')
assert run_obj.status == 'PENDING'
assert run_obj.retry_count == 1
assert run_obj.error_detail == 'API error'
assert run_obj.next_retry_at is not None
assert run_obj.claimed_by is None
@pytest.mark.asyncio
async def test_mark_run_failed_dead_letter(async_session_factory):
"""Failed runs that exceed max_retries go to DEAD_LETTER."""
async with async_session_factory() as session:
automation = make_automation()
run = make_run(run_id='fail-dead', retry_count=2, max_retries=3)
session.add_all([automation, run])
await session.commit()
async with async_session_factory() as session:
run_obj = await session.get(AutomationRun, 'fail-dead')
await _mark_run_failed(run_obj, 'Final failure', async_session_factory)
async with async_session_factory() as session:
run_obj = await session.get(AutomationRun, 'fail-dead')
assert run_obj.status == 'DEAD_LETTER'
assert run_obj.retry_count == 3
assert run_obj.error_detail == 'Final failure'
assert run_obj.completed_at is not None
# ---------------------------------------------------------------------------
# is_terminal
# ---------------------------------------------------------------------------
def test_is_terminal_stopped():
assert is_terminal({'status': 'STOPPED'}) is True
def test_is_terminal_error():
assert is_terminal({'status': 'ERROR'}) is True
def test_is_terminal_completed():
assert is_terminal({'status': 'COMPLETED'}) is True
def test_is_terminal_cancelled():
assert is_terminal({'status': 'CANCELLED'}) is True
def test_is_terminal_running():
assert is_terminal({'status': 'RUNNING'}) is False
def test_is_terminal_empty():
assert is_terminal({}) is False
def test_is_terminal_case_insensitive():
assert is_terminal({'status': 'stopped'}) is True
assert is_terminal({'status': 'Completed'}) is True
# ---------------------------------------------------------------------------
# find_matching_automations — None payload
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_find_matching_automations_none_payload(async_session):
"""Events with None payload return empty list (data corruption guard)."""
event = make_event(source_type='cron')
event.payload = None
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert result == []
# ---------------------------------------------------------------------------
# Integration: event → run creation → claim
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_integration_event_to_run_to_claim(
async_session_factory,
):
"""Full flow: create event + automation → process_new_events → claim_and_execute_runs.
Uses a real SQLite database; only the external API client is mocked.
"""
# 1. Seed an automation and a NEW event
async with async_session_factory() as session:
automation = make_automation(automation_id='integ-auto')
event = make_event(
source_type='cron',
payload={'automation_id': 'integ-auto'},
dedup_key='integ-dedup',
)
session.add_all([automation, event])
await session.commit()
event_id = event.id
# 2. Process inbox — should match and create a PENDING run
async with async_session_factory() as session:
processed = await process_new_events(session)
assert processed == 1
# Verify event is PROCESSED and run was created
async with async_session_factory() as session:
evt = await session.get(AutomationEvent, event_id)
assert evt.status == 'PROCESSED'
runs = (await session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 1
run = runs[0]
assert run.automation_id == 'integ-auto'
assert run.status == 'PENDING'
assert run.event_payload == {'automation_id': 'integ-auto'}
# 3. Claim the run — mock execute_run to avoid real API calls
api_client = AsyncMock()
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
async with async_session_factory() as session:
claimed = await claim_and_execute_runs(
session, 'executor-integ', api_client, async_session_factory
)
assert claimed is True
# 4. Verify the run moved to RUNNING with correct executor
async with async_session_factory() as session:
runs = (await session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 1
run = runs[0]
assert run.status == 'RUNNING'
assert run.claimed_by == 'executor-integ'
assert run.started_at is not None
assert run.heartbeat_at is not None

View File

@@ -1,185 +0,0 @@
"""Tests for OpenHandsAPIClient with mocked HTTP responses."""
import base64
import httpx
import pytest
from services.openhands_api_client import OpenHandsAPIClient
@pytest.fixture
def api_client():
client = OpenHandsAPIClient(base_url='http://test-server:3000')
yield client
# close handled in tests that need it
# ---------------------------------------------------------------------------
# start_conversation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_start_conversation_sends_correct_request(api_client, respx_mock):
"""start_conversation sends properly formatted POST with auth header."""
automation_file = b'print("hello")'
expected_b64 = base64.b64encode(automation_file).decode()
route = respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(
200,
json={
'app_conversation_id': 'conv-123',
'status': 'RUNNING',
},
)
)
result = await api_client.start_conversation(
api_key='sk-oh-test123',
automation_file=automation_file,
title='Test Automation',
event_payload={'automation_id': 'auto-1'},
)
assert route.called
request = route.calls[0].request
assert request.headers['Authorization'] == 'Bearer sk-oh-test123'
import json
body = json.loads(request.content)
assert body['automation_file'] == expected_b64
assert body['trigger'] == 'automation'
assert body['title'] == 'Test Automation'
assert body['event_payload'] == {'automation_id': 'auto-1'}
assert result == {'app_conversation_id': 'conv-123', 'status': 'RUNNING'}
@pytest.mark.asyncio
async def test_start_conversation_without_event_payload(api_client, respx_mock):
"""start_conversation works with event_payload=None."""
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(200, json={'app_conversation_id': 'conv-456'})
)
result = await api_client.start_conversation(
api_key='sk-oh-test',
automation_file=b'code',
title='Test',
event_payload=None,
)
assert result['app_conversation_id'] == 'conv-456'
@pytest.mark.asyncio
async def test_start_conversation_http_error(api_client, respx_mock):
"""start_conversation raises on HTTP errors."""
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(500, json={'error': 'Internal Server Error'})
)
with pytest.raises(httpx.HTTPStatusError) as exc_info:
await api_client.start_conversation(
api_key='sk-oh-test',
automation_file=b'code',
title='Test',
)
assert exc_info.value.response.status_code == 500
@pytest.mark.asyncio
async def test_start_conversation_auth_error(api_client, respx_mock):
"""start_conversation raises on 401 Unauthorized."""
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(401, json={'error': 'Unauthorized'})
)
with pytest.raises(httpx.HTTPStatusError) as exc_info:
await api_client.start_conversation(
api_key='bad-key',
automation_file=b'code',
title='Test',
)
assert exc_info.value.response.status_code == 401
# ---------------------------------------------------------------------------
# get_conversation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_conversation_returns_data(api_client, respx_mock):
"""get_conversation returns the first conversation from the list."""
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(
200,
json=[
{
'conversation_id': 'conv-123',
'status': 'RUNNING',
'title': 'My Automation',
}
],
)
)
result = await api_client.get_conversation('sk-oh-test', 'conv-123')
assert result is not None
assert result['conversation_id'] == 'conv-123'
assert result['status'] == 'RUNNING'
@pytest.mark.asyncio
async def test_get_conversation_returns_none_when_empty(api_client, respx_mock):
"""get_conversation returns None when API returns empty list."""
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(200, json=[])
)
result = await api_client.get_conversation('sk-oh-test', 'nonexistent')
assert result is None
@pytest.mark.asyncio
async def test_get_conversation_sends_auth_header(api_client, respx_mock):
"""get_conversation sends the correct authorization header."""
route = respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(200, json=[])
)
await api_client.get_conversation('sk-oh-mykey', 'conv-1')
assert route.called
request = route.calls[0].request
assert request.headers['Authorization'] == 'Bearer sk-oh-mykey'
@pytest.mark.asyncio
async def test_get_conversation_http_error(api_client, respx_mock):
"""get_conversation raises on HTTP errors."""
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(503, text='Service Unavailable')
)
with pytest.raises(httpx.HTTPStatusError):
await api_client.get_conversation('sk-oh-test', 'conv-1')
# ---------------------------------------------------------------------------
# close
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_close(api_client):
"""close() shuts down the HTTP client without errors."""
await api_client.close()