Compare commits

..

6 Commits

Author SHA1 Message Date
openhands
215e769735 fix: remove dead params from file generator, add timezone validation
- Remove unused repository and branch parameters from generate_automation_file()
- Add field_validator for timezone in CronTriggerModel using zoneinfo.ZoneInfo

Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-11 18:34:15 +00:00
openhands
03a816333b fix: align model server_default with migration (use text CURRENT_TIMESTAMP) 2026-03-11 18:22:50 +00:00
openhands
54433c5dae fix: rewrite publisher tests to use real imports, remove dead JSONB import
Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-11 18:11:51 +00:00
openhands
9269d045c8 fix: migration JSONB consistency, docstring escaping edge case, add name triple-quote test
- Replace sa.dialects.postgresql.JSONB() with sa.JSON() in migration for
  consistency with model definitions
- Sanitize name in docstring by replacing double quotes with single quotes
  to prevent triple-quote breakage
- Add test_triple_quotes_in_name test case

Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-11 17:53:35 +00:00
openhands
29d8990263 fix: address review feedback on automation data foundation
- Verify migration down_revision='099' is correct (no change needed)
- Change automation_events.id to BigInteger (BIGSERIAL) in migration and models
- Change event_id FK to BigInteger in AutomationRun
- Fix ix_automation_events_new index to be on created_at instead of status
- Fix updated_at to use func.now() for timezone-aware datetime
- Replace custom cron regex with croniter.is_valid() for validation
- Fix file generator escaping bug using repr() for safe string handling
- Add triple-quote test for file generator
- Add publisher tests (test_automation_event_publisher.py)
- Update test assertions to match new croniter error message

Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-11 08:27:09 +00:00
openhands
0c9af8290f [Automations Phase 1] Task 1: Data Foundation
- Add AUTOMATION variant to ConversationTrigger enum
- Create Automation and AutomationRun SQLAlchemy models (enterprise/storage/automation.py)
- Create AutomationEvent SQLAlchemy model (enterprise/storage/automation_event.py)
- Create Alembic migration 100 for automation tables with indexes
- Create automation config extraction and Pydantic validation (enterprise/services/automation_config.py)
- Create automation file generator (enterprise/services/automation_file_generator.py)
- Create automation event publisher with pg_notify (enterprise/services/automation_event_publisher.py)
- Add comprehensive unit tests (31 tests, all passing)
- Add croniter dependency to enterprise pyproject.toml

Part of RFC: https://github.com/OpenHands/OpenHands/issues/13275

Co-authored-by: openhands <openhands@all-hands.dev>
2026-03-10 18:39:44 +00:00
19 changed files with 967 additions and 1638 deletions

View File

@@ -0,0 +1,133 @@
"""Create automation tables (automations, automation_events, automation_runs)
Revision ID: 100
Revises: 099
Create Date: 2025-03-10 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '100'
down_revision: Union[str, None] = '099'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# --- automation_events (must come first, referenced by automation_runs) ---
op.create_table(
'automation_events',
sa.Column('id', sa.BigInteger(), sa.Identity(), nullable=False, primary_key=True),
sa.Column('source_type', sa.String(), nullable=False),
sa.Column('payload', sa.JSON(), nullable=False),
sa.Column('metadata', sa.JSON(), nullable=True),
sa.Column('dedup_key', sa.String(), nullable=False),
sa.Column('status', sa.String(), nullable=False, server_default=sa.text("'NEW'")),
sa.Column('error_detail', sa.Text(), nullable=True),
sa.Column(
'created_at',
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text('CURRENT_TIMESTAMP'),
),
sa.Column('processed_at', sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('dedup_key', name='uq_automation_events_dedup'),
)
op.create_index(
'ix_automation_events_new',
'automation_events',
['created_at'],
postgresql_where=sa.text("status = 'NEW'"),
)
# --- automations ---
op.create_table(
'automations',
sa.Column('id', sa.String(), nullable=False, primary_key=True),
sa.Column('user_id', sa.String(), nullable=False),
sa.Column('org_id', sa.String(), nullable=True),
sa.Column('name', sa.String(), nullable=False),
sa.Column('enabled', sa.Boolean(), nullable=False, server_default=sa.text('TRUE')),
sa.Column('config', sa.JSON(), nullable=False),
sa.Column('trigger_type', sa.String(), nullable=False),
sa.Column('file_store_key', sa.String(), nullable=False),
sa.Column('last_triggered_at', sa.DateTime(timezone=True), nullable=True),
sa.Column(
'created_at',
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text('CURRENT_TIMESTAMP'),
),
sa.Column(
'updated_at',
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text('CURRENT_TIMESTAMP'),
),
sa.PrimaryKeyConstraint('id'),
)
op.create_index('ix_automations_user_id', 'automations', ['user_id'])
op.create_index('ix_automations_org_id', 'automations', ['org_id'])
op.create_index('ix_automations_enabled_trigger', 'automations', ['enabled', 'trigger_type'])
# --- automation_runs ---
op.create_table(
'automation_runs',
sa.Column('id', sa.String(), nullable=False, primary_key=True),
sa.Column(
'automation_id',
sa.String(),
sa.ForeignKey('automations.id', ondelete='CASCADE'),
nullable=False,
),
sa.Column(
'event_id',
sa.BigInteger(),
sa.ForeignKey('automation_events.id'),
nullable=True,
),
sa.Column('conversation_id', sa.String(), nullable=True),
sa.Column('status', sa.String(), nullable=False, server_default=sa.text("'PENDING'")),
sa.Column('claimed_by', sa.String(), nullable=True),
sa.Column('claimed_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('heartbeat_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('retry_count', sa.Integer(), nullable=False, server_default=sa.text('0')),
sa.Column('max_retries', sa.Integer(), nullable=False, server_default=sa.text('3')),
sa.Column('next_retry_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('event_payload', sa.JSON(), nullable=True),
sa.Column('error_detail', sa.Text(), nullable=True),
sa.Column('started_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),
sa.Column(
'created_at',
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text('CURRENT_TIMESTAMP'),
),
sa.PrimaryKeyConstraint('id'),
)
op.create_index('ix_automation_runs_automation_id', 'automation_runs', ['automation_id'])
op.create_index(
'ix_automation_runs_claimable',
'automation_runs',
['status', 'next_retry_at'],
postgresql_where=sa.text("status = 'PENDING' AND (next_retry_at IS NULL OR next_retry_at <= now())"),
)
op.create_index(
'ix_automation_runs_heartbeat',
'automation_runs',
['heartbeat_at'],
postgresql_where=sa.text("status = 'RUNNING'"),
)
def downgrade() -> None:
op.drop_table('automation_runs')
op.drop_table('automations')
op.drop_table('automation_events')

36
enterprise/poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.3.2 and should not be changed by hand.
[[package]]
name = "agent-client-protocol"
@@ -1641,6 +1641,22 @@ files = [
{file = "crashtest-0.4.1.tar.gz", hash = "sha256:80d7b1f316ebfbd429f648076d6275c877ba30ba48979de4191714a75266f0ce"},
]
[[package]]
name = "croniter"
version = "6.0.0"
description = "croniter provides iteration for datetime object with cron like format"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.6"
groups = ["main"]
files = [
{file = "croniter-6.0.0-py2.py3-none-any.whl", hash = "sha256:2f878c3856f17896979b2a4379ba1f09c83e374931ea15cc835c5dd2eee9b368"},
{file = "croniter-6.0.0.tar.gz", hash = "sha256:37c504b313956114a983ece2c2b07790b1f1094fe9d81cc94739214748255577"},
]
[package.dependencies]
python-dateutil = "*"
pytz = ">2021.1"
[[package]]
name = "cryptography"
version = "46.0.5"
@@ -3501,7 +3517,7 @@ files = [
[package.dependencies]
googleapis-common-protos = ">=1.5.5"
grpcio = ">=1.71.2"
protobuf = ">=5.26.1,<6.0dev"
protobuf = ">=5.26.1,<6.0.dev0"
[[package]]
name = "gspread"
@@ -3819,7 +3835,7 @@ pfzy = ">=0.3.1,<0.4.0"
prompt-toolkit = ">=3.0.1,<4.0.0"
[package.extras]
docs = ["Sphinx (>=4.1.2,<5.0.0)", "furo (>=2021.8.17-beta.43,<2022.0.0)", "myst-parser (>=0.15.1,<0.16.0)", "sphinx-autobuild (>=2021.3.14,<2022.0.0)", "sphinx-copybutton (>=0.4.0,<0.5.0)"]
docs = ["Sphinx (>=4.1.2,<5.0.0)", "furo (>=2021.8.17b43,<2022.0.0)", "myst-parser (>=0.15.1,<0.16.0)", "sphinx-autobuild (>=2021.3.14,<2022.0.0)", "sphinx-copybutton (>=0.4.0,<0.5.0)"]
[[package]]
name = "installer"
@@ -4258,7 +4274,7 @@ fqdn = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
idna = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
isoduration = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
jsonpointer = {version = ">1.13", optional = true, markers = "extra == \"format-nongpl\""}
jsonschema-specifications = ">=2023.03.6"
jsonschema-specifications = ">=2023.3.6"
referencing = ">=0.28.4"
rfc3339-validator = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
rfc3986-validator = {version = ">0.1.0", optional = true, markers = "extra == \"format-nongpl\""}
@@ -4648,7 +4664,7 @@ files = [
]
[package.dependencies]
certifi = ">=14.05.14"
certifi = ">=14.5.14"
durationpy = ">=0.7"
google-auth = ">=1.0.1"
oauthlib = ">=3.2.2"
@@ -6889,7 +6905,7 @@ files = [
]
[package.extras]
docs = ["Sphinx (>=4.1.2,<5.0.0)", "furo (>=2021.8.17-beta.43,<2022.0.0)", "myst-parser (>=0.15.1,<0.16.0)", "sphinx-autobuild (>=2021.3.14,<2022.0.0)", "sphinx-copybutton (>=0.4.0,<0.5.0)"]
docs = ["Sphinx (>=4.1.2,<5.0.0)", "furo (>=2021.8.17b43,<2022.0.0)", "myst-parser (>=0.15.1,<0.16.0)", "sphinx-autobuild (>=2021.3.14,<2022.0.0)", "sphinx-copybutton (>=0.4.0,<0.5.0)"]
[[package]]
name = "pg8000"
@@ -12866,10 +12882,10 @@ files = [
]
[package.dependencies]
botocore = ">=1.37.4,<2.0a.0"
botocore = ">=1.37.4,<2.0a0"
[package.extras]
crt = ["botocore[crt] (>=1.37.4,<2.0a.0)"]
crt = ["botocore[crt] (>=1.37.4,<2.0a0)"]
[[package]]
name = "scantree"
@@ -15006,9 +15022,9 @@ files = [
]
[package.extras]
cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and python_version < \"3.14\"", "cffi (>=2.0.0b) ; platform_python_implementation != \"PyPy\" and python_version >= \"3.14\""]
cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and python_version < \"3.14\"", "cffi (>=2.0.0b0) ; platform_python_implementation != \"PyPy\" and python_version >= \"3.14\""]
[metadata]
lock-version = "2.1"
python-versions = "^3.12,<3.14"
content-hash = "ef037f6d6085d26166d35c56ce266439f8f1a4fea90bc43ccf15cfeaf116cae5"
content-hash = "d07fdf5fbc8eaf4ed30c119b0c05081d0eb3df60732e530002db5eec84ef080b"

View File

@@ -17,6 +17,7 @@ packages = [
{ include = "storage" },
{ include = "sync" },
{ include = "integrations" },
{ include = "services" },
]
[tool.poetry.dependencies]
@@ -49,6 +50,7 @@ pandas = "^2.2.0"
numpy = "^2.2.0"
mcp = "^1.10.0"
pillow = "^12.1.1"
croniter = "^6.0.0"
[tool.poetry.group.dev.dependencies]
ruff = "0.8.3"

View File

@@ -1,71 +0,0 @@
"""Entry point for the automation executor.
Usage: python -m run_automation_executor
This runs as a Kubernetes Deployment (long-running). It polls the automation_events
inbox, matches events to automations, claims and executes runs, and monitors
conversation completion.
Environment variables:
OPENHANDS_API_URL Base URL for the V1 API (default: http://openhands-service:3000)
MAX_CONCURRENT_RUNS Max concurrent runs per executor (default: 5)
RUN_TIMEOUT_SECONDS Max time for a single run (default: 7200)
POLL_INTERVAL_SECONDS Fallback poll interval (default: 30)
HEARTBEAT_INTERVAL_SECONDS Heartbeat update interval (default: 60)
"""
import asyncio
import logging
import signal
import sys
logger = logging.getLogger('saas.automation.executor')
def _setup_logging() -> None:
"""Configure logging, deferring to enterprise logger if available."""
try:
from server.logger import setup_all_loggers
setup_all_loggers()
except ImportError:
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s %(name)s %(levelname)s %(message)s',
stream=sys.stdout,
)
def _install_signal_handlers(loop: asyncio.AbstractEventLoop) -> None:
"""Install signal handlers for graceful shutdown."""
from services.automation_executor import request_shutdown
def _handle_signal(signum: int, _frame: object) -> None:
sig_name = signal.Signals(signum).name
logger.info('Received %s, initiating graceful shutdown...', sig_name)
request_shutdown()
for sig in (signal.SIGTERM, signal.SIGINT):
signal.signal(sig, _handle_signal)
async def main() -> None:
from services.automation_executor import executor_main
await executor_main()
if __name__ == '__main__':
_setup_logging()
loop = asyncio.new_event_loop()
_install_signal_handlers(loop)
logger.info('Starting automation executor')
try:
loop.run_until_complete(main())
except KeyboardInterrupt:
logger.info('Interrupted by user')
finally:
loop.close()
logger.info('Automation executor process exiting')

View File

@@ -0,0 +1,121 @@
"""Automation config extraction and validation.
Parses ``__config__`` from automation Python source files
and validates it against a Pydantic schema.
"""
from __future__ import annotations
import ast
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from croniter import croniter
from pydantic import BaseModel, field_validator, model_validator
def extract_config(source: str) -> dict:
"""Extract the ``__config__`` dict from automation Python source code.
Uses :mod:`ast` to safely parse the source and locate a module-level
assignment to ``__config__``. The value must be a literal expression
(evaluated via :func:`ast.literal_eval`).
Raises:
ValueError: If ``__config__`` is not found or its value contains
non-literal expressions.
"""
try:
tree = ast.parse(source)
except SyntaxError as exc:
raise ValueError(f'Failed to parse source: {exc}') from exc
for node in ast.iter_child_nodes(tree):
if isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and target.id == '__config__':
try:
value = ast.literal_eval(node.value)
except (ValueError, TypeError) as exc:
raise ValueError(
f'__config__ value must be a literal expression: {exc}'
) from exc
if not isinstance(value, dict):
raise ValueError('__config__ must be a dict')
return value
# Handle annotated assignment: __config__: dict = {...}
if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
if node.target.id == '__config__' and node.value is not None:
try:
value = ast.literal_eval(node.value)
except (ValueError, TypeError) as exc:
raise ValueError(
f'__config__ value must be a literal expression: {exc}'
) from exc
if not isinstance(value, dict):
raise ValueError('__config__ must be a dict')
return value
raise ValueError('__config__ not found in source')
class CronTriggerModel(BaseModel):
"""Cron trigger configuration."""
schedule: str
timezone: str = 'UTC'
@field_validator('schedule')
@classmethod
def validate_schedule(cls, v: str) -> str:
v = v.strip()
if not croniter.is_valid(v):
raise ValueError(f'Invalid cron expression: {v!r}')
return v
@field_validator('timezone')
@classmethod
def validate_timezone(cls, v: str) -> str:
v = v.strip()
try:
ZoneInfo(v)
except (ZoneInfoNotFoundError, KeyError):
raise ValueError(f'Invalid timezone: {v!r}')
return v
class TriggersModel(BaseModel):
"""Container for trigger definitions. Exactly one trigger must be set."""
cron: CronTriggerModel | None = None
@model_validator(mode='after')
def exactly_one_trigger(self) -> TriggersModel:
defined = [name for name in ('cron',) if getattr(self, name) is not None]
if len(defined) != 1:
raise ValueError(f'Exactly one trigger must be defined, got: {defined}')
return self
class AutomationConfigModel(BaseModel):
"""Top-level automation config schema."""
name: str
triggers: TriggersModel
description: str = ''
@field_validator('name')
@classmethod
def validate_name(cls, v: str) -> str:
v = v.strip()
if not (1 <= len(v) <= 200):
raise ValueError('name must be between 1 and 200 characters')
return v
def validate_config(config: dict) -> AutomationConfigModel:
"""Validate a ``__config__`` dict against the automation schema.
Returns the parsed :class:`AutomationConfigModel` or raises
:class:`pydantic.ValidationError`.
"""
return AutomationConfigModel.model_validate(config)

View File

@@ -0,0 +1,36 @@
"""Publish automation events and notify listeners via PostgreSQL NOTIFY."""
from __future__ import annotations
from sqlalchemy import text
from sqlalchemy.orm import Session
from storage.automation_event import AutomationEvent
def publish_automation_event(
session: Session,
source_type: str,
payload: dict,
dedup_key: str,
metadata: dict | None = None,
) -> AutomationEvent:
"""Create an :class:`AutomationEvent` and add it to the session.
The caller is responsible for committing (or flushing) the session.
"""
event = AutomationEvent(
source_type=source_type,
payload=payload,
dedup_key=dedup_key,
metadata_=metadata,
)
session.add(event)
return event
def pg_notify_new_event(session: Session, event_id: int) -> None:
"""Send a PostgreSQL ``NOTIFY`` on the ``automation_events`` channel."""
session.execute(
text("SELECT pg_notify('automation_events', :event_id)"),
{'event_id': str(event_id)},
)

View File

@@ -1,555 +0,0 @@
"""Automation executor — processes events, claims and executes runs.
The executor is a long-running process with three phases:
1. Process inbox: match NEW events to automations, create PENDING runs
2. Claim and execute: claim PENDING runs, submit to V1 API, heartbeat
3. Stale recovery: recover RUNNING runs with expired heartbeats
"""
import asyncio
import logging
import os
import socket
from datetime import datetime, timedelta, timezone
from uuid import uuid4
from services.openhands_api_client import OpenHandsAPIClient
from sqlalchemy import or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from storage.automation import Automation, AutomationRun
from storage.automation_event import AutomationEvent
logger = logging.getLogger('saas.automation.executor')
# Environment-configurable settings
POLL_INTERVAL_SECONDS = float(os.getenv('POLL_INTERVAL_SECONDS', '30'))
HEARTBEAT_INTERVAL_SECONDS = float(os.getenv('HEARTBEAT_INTERVAL_SECONDS', '60'))
RUN_TIMEOUT_SECONDS = float(os.getenv('RUN_TIMEOUT_SECONDS', '7200'))
MAX_CONCURRENT_RUNS = int(os.getenv('MAX_CONCURRENT_RUNS', '5'))
STALE_THRESHOLD_MINUTES = 5
MAX_EVENTS_PER_BATCH = 50
MAX_RETRIES_DEFAULT = 3
# Terminal conversation statuses
TERMINAL_STATUSES = frozenset({'STOPPED', 'ERROR', 'COMPLETED', 'CANCELLED'})
# Shutdown flag — set by signal handlers
_shutdown_event: asyncio.Event | None = None
# Background task tracking for graceful shutdown
_pending_tasks: set[asyncio.Task] = set()
def utc_now() -> datetime:
return datetime.now(timezone.utc)
def get_shutdown_event() -> asyncio.Event:
global _shutdown_event
if _shutdown_event is None:
_shutdown_event = asyncio.Event()
return _shutdown_event
def should_continue() -> bool:
return not get_shutdown_event().is_set()
def request_shutdown() -> None:
get_shutdown_event().set()
# ---------------------------------------------------------------------------
# Phase 1: Process inbox (event matching)
# ---------------------------------------------------------------------------
async def find_matching_automations(
session: AsyncSession, event: AutomationEvent
) -> list[Automation]:
"""Find automations that match the given event.
Phase 1 supports cron and manual triggers only — both carry
``automation_id`` in the event payload.
"""
source_type = event.source_type
payload = event.payload
if payload is None:
logger.error('Event %s has None payload — possible data corruption', event.id)
return []
if source_type in ('cron', 'manual'):
automation_id = payload.get('automation_id')
if not automation_id:
logger.warning(
'Event %s (source=%s) missing automation_id in payload',
event.id,
source_type,
)
return []
result = await session.execute(
select(Automation).where(
Automation.id == automation_id,
Automation.enabled.is_(True),
)
)
automation = result.scalar_one_or_none()
return [automation] if automation else []
logger.debug('Unhandled event source_type=%s for event %s', source_type, event.id)
return []
async def process_new_events(session: AsyncSession) -> int:
"""Claim NEW events from inbox, match to automations, create runs.
Returns the number of events processed.
"""
result = await session.execute(
select(AutomationEvent)
.where(AutomationEvent.status == 'NEW')
.order_by(AutomationEvent.created_at)
.limit(MAX_EVENTS_PER_BATCH)
.with_for_update(skip_locked=True)
)
events = list(result.scalars())
processed = 0
for event in events:
try:
automations = await find_matching_automations(session, event)
if not automations:
event.status = 'NO_MATCH'
event.processed_at = utc_now()
else:
for automation in automations:
run = AutomationRun(
id=uuid4().hex,
automation_id=automation.id,
event_id=event.id,
status='PENDING',
event_payload=event.payload,
)
session.add(run)
event.status = 'PROCESSED'
event.processed_at = utc_now()
processed += 1
except Exception as e:
logger.exception('Error processing event %s', event.id)
event.status = 'ERROR'
event.error_detail = f'Failed during event matching: {type(e).__name__}: {e}'
event.processed_at = utc_now()
if processed:
await session.commit()
logger.info('Processed %d events', processed)
return processed
# ---------------------------------------------------------------------------
# Phase 2: Claim and execute runs
# ---------------------------------------------------------------------------
async def resolve_user_api_key(session: AsyncSession, user_id: str) -> str | None:
"""Look up a user's API key from the api_keys table.
Returns the first active key found, or None.
"""
from storage.api_key import ApiKey
result = await session.execute(
select(ApiKey.key).where(ApiKey.user_id == user_id).limit(1)
)
row = result.scalar_one_or_none()
return row
async def download_automation_file(file_store_key: str) -> bytes:
"""Download the automation .py file from object storage."""
try:
from openhands.server.shared import file_store
except ImportError as exc:
raise RuntimeError(
'file_store is not available — ensure the enterprise server '
'has been initialised before calling download_automation_file'
) from exc
content = file_store.read(file_store_key)
if isinstance(content, str):
return content.encode('utf-8')
return content
def is_terminal(conversation: dict) -> bool:
"""Check if a conversation has reached a terminal status."""
status = (conversation.get('status') or '').upper()
return status in TERMINAL_STATUSES
async def _prepare_run(
run: AutomationRun,
automation: Automation,
session_factory: object,
) -> tuple[str, bytes]:
"""Resolve the user's API key and download the automation file.
Returns:
(api_key, automation_file) tuple ready for submission.
Raises:
ValueError: If no API key is found.
RuntimeError: If file_store is unavailable.
"""
async with session_factory() as key_session:
api_key = await resolve_user_api_key(key_session, automation.user_id)
if not api_key:
raise ValueError(f'No API key found for user {automation.user_id}')
automation_file = await download_automation_file(automation.file_store_key)
return api_key, automation_file
async def _monitor_conversation(
run: AutomationRun,
conversation_id: str,
api_client: OpenHandsAPIClient,
api_key: str,
session_factory: object,
) -> bool:
"""Monitor a conversation until completion or timeout.
Returns True if completed successfully, False if shutdown requested.
Raises:
TimeoutError: If the run exceeds RUN_TIMEOUT_SECONDS.
"""
start_time = utc_now()
while should_continue():
elapsed = (utc_now() - start_time).total_seconds()
if elapsed > RUN_TIMEOUT_SECONDS:
raise TimeoutError(f'Run exceeded {RUN_TIMEOUT_SECONDS}s timeout')
await asyncio.sleep(HEARTBEAT_INTERVAL_SECONDS)
# Update heartbeat
async with session_factory() as session:
run_obj = await session.get(AutomationRun, run.id)
if run_obj:
run_obj.heartbeat_at = utc_now()
await session.commit()
# Check conversation status
conversation = (
await api_client.get_conversation(api_key, conversation_id) or {}
)
if is_terminal(conversation):
return True
return False # shutdown requested
async def _submit_and_monitor(
run: AutomationRun,
api_key: str,
automation_file: bytes,
automation: Automation,
api_client: OpenHandsAPIClient,
session_factory: object,
) -> None:
"""Submit the automation to the V1 API and monitor until completion.
Updates the run's conversation_id, sends heartbeats, and marks the
final status when the conversation reaches a terminal state.
"""
conversation = await api_client.start_conversation(
api_key=api_key,
automation_file=automation_file,
title=f'Automation: {automation.name}',
event_payload=run.event_payload,
)
conversation_id = conversation.get('app_conversation_id') or conversation.get(
'conversation_id'
)
# Persist conversation ID
async with session_factory() as update_session:
run_obj = await update_session.get(AutomationRun, run.id)
if run_obj:
run_obj.conversation_id = conversation_id
await update_session.commit()
# Monitor with heartbeats
completed = await _monitor_conversation(
run, conversation_id, api_client, api_key, session_factory
)
# Update final status
async with session_factory() as final_session:
run_obj = await final_session.get(AutomationRun, run.id)
if run_obj:
if not completed:
# Leave as RUNNING — stale recovery will handle it if needed.
# The conversation may still be running on the API side.
logger.info(
'Run %s left as RUNNING due to executor shutdown', run.id
)
else:
run_obj.status = 'COMPLETED'
run_obj.completed_at = utc_now()
logger.info('Run %s completed successfully', run.id)
await final_session.commit()
async def execute_run(
run: AutomationRun,
automation: Automation,
api_client: OpenHandsAPIClient,
session_factory: object,
) -> None:
"""Execute a single automation run end-to-end.
Orchestrates preparation (API key + file download) and submission/monitoring.
On failure, marks the run for retry or dead-letter.
"""
try:
api_key, automation_file = await _prepare_run(
run, automation, session_factory
)
await _submit_and_monitor(
run, api_key, automation_file, automation, api_client, session_factory
)
except Exception as e:
logger.exception('Run %s failed: %s', run.id, e)
await _mark_run_failed(run, str(e), session_factory)
async def _mark_run_failed(
run: AutomationRun, error: str, session_factory: object
) -> None:
"""Mark a run as FAILED or return to PENDING for retry."""
async with session_factory() as session:
run_obj = await session.get(AutomationRun, run.id)
if not run_obj:
return
run_obj.retry_count = (run_obj.retry_count or 0) + 1
run_obj.error_detail = error
if run_obj.retry_count >= (run_obj.max_retries or MAX_RETRIES_DEFAULT):
run_obj.status = 'DEAD_LETTER'
run_obj.completed_at = utc_now()
logger.error(
'Run %s moved to DEAD_LETTER after %d retries',
run.id,
run_obj.retry_count,
)
else:
run_obj.status = 'PENDING'
run_obj.claimed_by = None
backoff_seconds = 30 * (2 ** (run_obj.retry_count - 1))
run_obj.next_retry_at = utc_now() + timedelta(seconds=backoff_seconds)
logger.warning(
'Run %s returned to PENDING, retry %d/%d in %ds',
run.id,
run_obj.retry_count,
run_obj.max_retries or MAX_RETRIES_DEFAULT,
backoff_seconds,
)
await session.commit()
async def claim_and_execute_runs(
session: AsyncSession,
executor_id: str,
api_client: OpenHandsAPIClient,
session_factory: object,
) -> bool:
"""Claim a PENDING run and start executing it.
Returns True if a run was claimed, False otherwise.
"""
result = await session.execute(
select(AutomationRun)
.where(
AutomationRun.status == 'PENDING',
or_(
AutomationRun.next_retry_at.is_(None),
AutomationRun.next_retry_at <= utc_now(),
),
)
.order_by(AutomationRun.created_at)
.limit(1)
.with_for_update(skip_locked=True)
)
run = result.scalar_one_or_none()
if not run:
return False
# Claim the run
run.status = 'RUNNING'
run.claimed_by = executor_id
run.claimed_at = utc_now()
run.heartbeat_at = utc_now()
run.started_at = utc_now()
await session.commit()
# Load automation for the run
auto_result = await session.execute(
select(Automation).where(Automation.id == run.automation_id)
)
automation = auto_result.scalar_one_or_none()
if not automation:
logger.error('Automation %s not found for run %s', run.automation_id, run.id)
await _mark_run_failed(
run, f'Automation {run.automation_id} not found', session_factory
)
return True
# Execute in background (long-running) with task tracking
task = asyncio.create_task(
execute_run(run, automation, api_client, session_factory),
name=f'execute-run-{run.id}',
)
_pending_tasks.add(task)
task.add_done_callback(_pending_tasks.discard)
logger.info(
'Claimed run %s (automation=%s) by executor %s',
run.id,
run.automation_id,
executor_id,
)
return True
# ---------------------------------------------------------------------------
# Phase 3: Stale run recovery
# ---------------------------------------------------------------------------
async def recover_stale_runs(session: AsyncSession) -> int:
"""Mark RUNNING runs with expired heartbeats as PENDING for retry.
Returns the number of recovered runs.
"""
stale_threshold = utc_now() - timedelta(minutes=STALE_THRESHOLD_MINUTES)
timeout_threshold = utc_now() - timedelta(seconds=RUN_TIMEOUT_SECONDS)
# Recover stale runs (heartbeat expired)
result = await session.execute(
update(AutomationRun)
.where(
AutomationRun.status == 'RUNNING',
AutomationRun.heartbeat_at < stale_threshold,
AutomationRun.heartbeat_at >= timeout_threshold,
)
.values(
status='PENDING',
claimed_by=None,
retry_count=AutomationRun.retry_count + 1,
next_retry_at=utc_now() + timedelta(seconds=30),
)
.returning(AutomationRun.id)
)
recovered_rows = result.fetchall()
# Mark truly timed-out runs as DEAD_LETTER
timeout_result = await session.execute(
update(AutomationRun)
.where(
AutomationRun.status == 'RUNNING',
AutomationRun.heartbeat_at < timeout_threshold,
)
.values(
status='DEAD_LETTER',
error_detail='Run exceeded timeout',
completed_at=utc_now(),
)
.returning(AutomationRun.id)
)
timed_out_rows = timeout_result.fetchall()
await session.commit()
recovered_count = len(recovered_rows)
timed_out_count = len(timed_out_rows)
if recovered_count:
logger.warning('Recovered %d stale automation runs', recovered_count)
if timed_out_count:
logger.warning(
'Marked %d automation runs as DEAD_LETTER (timeout)', timed_out_count
)
return recovered_count + timed_out_count
# ---------------------------------------------------------------------------
# Main executor loop
# ---------------------------------------------------------------------------
async def executor_main(session_factory: object | None = None) -> None:
"""Main executor loop.
Args:
session_factory: Async context manager that yields AsyncSession instances.
If None, uses the default ``a_session_maker`` from database module.
"""
if session_factory is None:
from storage.database import a_session_maker
session_factory = a_session_maker
executor_id = f'executor-{socket.gethostname()}-{os.getpid()}'
api_url = os.getenv('OPENHANDS_API_URL', 'http://openhands-service:3000')
api_client = OpenHandsAPIClient(base_url=api_url)
logger.info(
'Automation executor %s starting (api_url=%s, poll=%ss, heartbeat=%ss)',
executor_id,
api_url,
POLL_INTERVAL_SECONDS,
HEARTBEAT_INTERVAL_SECONDS,
)
try:
while should_continue():
try:
async with session_factory() as session:
await process_new_events(session)
async with session_factory() as session:
await claim_and_execute_runs(
session, executor_id, api_client, session_factory
)
async with session_factory() as session:
await recover_stale_runs(session)
except Exception:
logger.exception('Error in executor main loop iteration')
# Wait for next poll interval (or early wakeup on shutdown)
try:
await asyncio.wait_for(
get_shutdown_event().wait(),
timeout=POLL_INTERVAL_SECONDS,
)
except asyncio.TimeoutError:
pass # Normal — poll interval elapsed
finally:
if _pending_tasks:
logger.info(
'Waiting for %d running tasks to complete...', len(_pending_tasks)
)
await asyncio.gather(*_pending_tasks, return_exceptions=True)
await api_client.close()
logger.info('Automation executor %s shut down', executor_id)

View File

@@ -0,0 +1,56 @@
"""Generate automation Python files from user-provided parameters."""
from __future__ import annotations
import textwrap
def generate_automation_file(
name: str,
schedule: str,
timezone: str,
prompt: str,
) -> str:
"""Return a complete, valid Python file string for an automation.
The generated file includes a ``__config__`` dict that can be round-tripped
through :func:`services.automation_config.extract_config` and
:func:`services.automation_config.validate_config`.
"""
# Use repr() for safe string escaping — handles backslashes, quotes, etc.
r_name = repr(name)
r_schedule = repr(schedule)
r_timezone = repr(timezone)
r_prompt = repr(prompt)
# Build a safe docstring — replace double quotes with single quotes to
# prevent triple-quote breakage in the docstring.
safe_docstring_name = name.replace('"', "'")
return textwrap.dedent(f'''\
"""{safe_docstring_name} — auto-generated automation."""
__config__ = {{
"name": {r_name},
"triggers": {{
"cron": {{
"schedule": {r_schedule},
"timezone": {r_timezone},
}}
}},
}}
import os
from openhands.sdk import LLM, Conversation
from openhands.tools.preset.default import get_default_agent
llm = LLM(
model=os.getenv("LLM_MODEL", "anthropic/claude-sonnet-4-5-20250929"),
api_key=os.getenv("LLM_API_KEY"),
base_url=os.getenv("LLM_BASE_URL"),
)
agent = get_default_agent(llm=llm, cli_mode=True)
conversation = Conversation(agent=agent, workspace=os.getcwd())
conversation.send_message({r_prompt})
conversation.run()
''')

View File

@@ -1,93 +0,0 @@
"""HTTP client for the main OpenHands V1 API (internal cluster calls).
Used by the automation executor to create and monitor conversations
in the main OpenHands server.
"""
import base64
import logging
import httpx
logger = logging.getLogger('saas.automation.api_client')
def _raise_with_body(resp: httpx.Response) -> None:
"""Call raise_for_status, enriching the error with the response body."""
try:
resp.raise_for_status()
except httpx.HTTPStatusError as e:
error_body = resp.text[:500] if resp.text else 'no response body'
raise httpx.HTTPStatusError(
f'{e.args[0]} — Response: {error_body}',
request=e.request,
response=e.response,
) from e
class OpenHandsAPIClient:
"""Async HTTP client for the OpenHands V1 API."""
def __init__(self, base_url: str = 'http://openhands-service:3000'):
self.base_url = base_url.rstrip('/')
self.client = httpx.AsyncClient(base_url=self.base_url, timeout=30.0)
async def start_conversation(
self,
api_key: str,
automation_file: bytes,
title: str,
event_payload: dict | None = None,
) -> dict:
"""Submit an SDK script for sandboxed execution via V1 API.
Args:
api_key: User's API key for authentication.
automation_file: Raw bytes of the .py automation script.
title: Display title for the conversation.
event_payload: Optional trigger event data (injected as env var).
Returns:
Parsed JSON response containing conversation details.
Raises:
httpx.HTTPStatusError: If the API returns a non-2xx status.
"""
resp = await self.client.post(
'/api/v1/app-conversations',
json={
'automation_file': base64.b64encode(automation_file).decode(),
'trigger': 'automation',
'title': title,
'event_payload': event_payload,
},
headers={'Authorization': f'Bearer {api_key}'},
)
_raise_with_body(resp)
return resp.json()
async def get_conversation(self, api_key: str, conversation_id: str) -> dict | None:
"""Get conversation status.
Args:
api_key: User's API key for authentication.
conversation_id: The conversation ID to look up.
Returns:
Conversation data dict, or None if not found.
Raises:
httpx.HTTPStatusError: If the API returns a non-2xx status.
"""
resp = await self.client.get(
'/api/v1/app-conversations',
params={'ids': [conversation_id]},
headers={'Authorization': f'Bearer {api_key}'},
)
_raise_with_body(resp)
conversations = resp.json()
return conversations[0] if conversations else None
async def close(self) -> None:
"""Close the underlying HTTP client."""
await self.client.aclose()

View File

@@ -1,61 +1,78 @@
"""SQLAlchemy models for automations and automation runs.
Stub for Task 1 (Data Foundation). These models will be replaced when Task 1
is merged into automations-phase1.
"""
from __future__ import annotations
from sqlalchemy import (
BigInteger,
Boolean,
Column,
DateTime,
ForeignKey,
Index,
Integer,
String,
Text,
func,
text,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
from sqlalchemy.types import JSON
from storage.base import Base
# Use JSON with JSONB variant so models work on SQLite (tests) and PostgreSQL (prod)
_JsonType = JSON().with_variant(JSONB(), 'postgresql')
class Automation(Base): # type: ignore
"""Model for storing automation definitions."""
class Automation(Base):
__tablename__ = 'automations'
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False, index=True)
org_id = Column(String, nullable=True, index=True)
user_id = Column(String, nullable=False)
org_id = Column(String, nullable=True)
name = Column(String, nullable=False)
enabled = Column(Boolean, nullable=False, server_default=text('true'))
config = Column(JSON, nullable=False)
enabled = Column(Boolean, nullable=False, server_default=text('TRUE'))
config = Column(_JsonType, nullable=False)
trigger_type = Column(String, nullable=False)
file_store_key = Column(String, nullable=False)
last_triggered_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column(
DateTime(timezone=True),
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
nullable=False,
)
updated_at = Column(
DateTime(timezone=True),
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
onupdate=func.now(),
nullable=False,
)
runs = relationship('AutomationRun', back_populates='automation')
runs = relationship('AutomationRun', back_populates='automation', cascade='all, delete-orphan')
__table_args__ = (
Index('ix_automations_user_id', 'user_id'),
Index('ix_automations_org_id', 'org_id'),
Index('ix_automations_enabled_trigger', 'enabled', 'trigger_type'),
)
class AutomationRun(Base):
class AutomationRun(Base): # type: ignore
"""Model for storing automation run records."""
__tablename__ = 'automation_runs'
id = Column(String, primary_key=True)
automation_id = Column(
String, ForeignKey('automations.id', ondelete='CASCADE'), nullable=False
String,
ForeignKey('automations.id', ondelete='CASCADE'),
nullable=False,
)
event_id = Column(
BigInteger,
ForeignKey('automation_events.id'),
nullable=True,
)
event_id = Column(Integer, ForeignKey('automation_events.id'), nullable=True)
conversation_id = Column(String, nullable=True)
status = Column(String, nullable=False, server_default=text("'PENDING'"))
claimed_by = Column(String, nullable=True)
@@ -64,14 +81,29 @@ class AutomationRun(Base):
retry_count = Column(Integer, nullable=False, server_default=text('0'))
max_retries = Column(Integer, nullable=False, server_default=text('3'))
next_retry_at = Column(DateTime(timezone=True), nullable=True)
event_payload = Column(JSON, nullable=True)
event_payload = Column(_JsonType, nullable=True)
error_detail = Column(Text, nullable=True)
started_at = Column(DateTime(timezone=True), nullable=True)
completed_at = Column(DateTime(timezone=True), nullable=True)
created_at = Column(
DateTime(timezone=True),
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
nullable=False,
)
automation = relationship('Automation', back_populates='runs')
__table_args__ = (
Index(
'ix_automation_runs_claimable',
'status',
'next_retry_at',
postgresql_where=text("status = 'PENDING' AND (next_retry_at IS NULL OR next_retry_at <= now())"),
),
Index('ix_automation_runs_automation_id', 'automation_id'),
Index(
'ix_automation_runs_heartbeat',
'heartbeat_at',
postgresql_where=text("status = 'RUNNING'"),
),
)

View File

@@ -1,27 +1,37 @@
"""SQLAlchemy model for automation events (the inbox).
from __future__ import annotations
Stub for Task 1 (Data Foundation). This model will be replaced when Task 1
is merged into automations-phase1.
"""
from sqlalchemy import Column, DateTime, Integer, String, Text, text
from sqlalchemy import BigInteger, Column, DateTime, Index, String, Text, UniqueConstraint, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.types import JSON
from storage.base import Base
_JsonType = JSON().with_variant(JSONB(), 'postgresql')
class AutomationEvent(Base): # type: ignore
"""Model for storing raw automation trigger events."""
class AutomationEvent(Base):
__tablename__ = 'automation_events'
id = Column(Integer, primary_key=True, autoincrement=True)
id = Column(BigInteger, primary_key=True, autoincrement=True)
source_type = Column(String, nullable=False)
payload = Column(JSON, nullable=False)
metadata_ = Column('metadata', JSON, nullable=True)
payload = Column(_JsonType, nullable=False)
metadata_ = Column('metadata', _JsonType, nullable=True)
dedup_key = Column(String, nullable=False, unique=True)
status = Column(String, nullable=False, server_default=text("'NEW'"))
error_detail = Column(Text, nullable=True)
created_at = Column(
DateTime(timezone=True),
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
nullable=False,
)
processed_at = Column(DateTime(timezone=True), nullable=True)
__table_args__ = (
UniqueConstraint('dedup_key', name='uq_automation_events_dedup'),
Index(
'ix_automation_events_new',
'created_at',
postgresql_where=text("status = 'NEW'"),
),
)

View File

@@ -1,68 +0,0 @@
"""Shared fixtures for services tests.
Note: We pre-load ``storage`` as a namespace package to avoid the heavy
``storage/__init__.py`` that imports the entire enterprise model graph.
This must happen *before* any ``from storage.…`` import.
"""
import contextlib
import sys
import types
# Prevent storage/__init__.py from loading the full model graph.
# We only need the lightweight automation models for these tests.
if 'storage' not in sys.modules:
import pathlib
_storage_dir = str(pathlib.Path(__file__).resolve().parents[3] / 'storage')
_mod = types.ModuleType('storage')
_mod.__path__ = [_storage_dir]
sys.modules['storage'] = _mod
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.automation import Automation, AutomationRun # noqa: F401
from storage.automation_event import AutomationEvent # noqa: F401
from storage.base import Base
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session_factory(async_engine):
"""Create an async session factory that yields context-managed sessions."""
factory = async_sessionmaker(
bind=async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
@contextlib.asynccontextmanager
async def _session_ctx():
async with factory() as session:
yield session
return _session_ctx
@pytest.fixture
async def async_session(async_session_factory):
"""Create a single async session for testing."""
async with async_session_factory() as session:
yield session

View File

@@ -0,0 +1,160 @@
"""Tests for automation config extraction and validation."""
import pytest
from pydantic import ValidationError
from services.automation_config import extract_config, validate_config
# ---------------------------------------------------------------------------
# extract_config
# ---------------------------------------------------------------------------
class TestExtractConfig:
def test_plain_dict(self):
source = '''
__config__ = {
"name": "Daily Report",
"triggers": {"cron": {"schedule": "0 9 * * 1"}},
}
'''
cfg = extract_config(source)
assert cfg['name'] == 'Daily Report'
assert cfg['triggers']['cron']['schedule'] == '0 9 * * 1'
def test_annotated_assignment(self):
source = '''
__config__: dict = {
"name": "Annotated",
"triggers": {"cron": {"schedule": "*/5 * * * *"}},
"description": "annotated config",
}
'''
cfg = extract_config(source)
assert cfg['name'] == 'Annotated'
assert cfg['description'] == 'annotated config'
def test_no_config_raises(self):
source = 'x = 1\ny = 2\n'
with pytest.raises(ValueError, match='__config__ not found'):
extract_config(source)
def test_non_literal_value_raises(self):
source = '__config__ = some_function()\n'
with pytest.raises(ValueError, match='literal expression'):
extract_config(source)
def test_bad_syntax_raises(self):
source = 'def foo(\n'
with pytest.raises(ValueError, match='Failed to parse'):
extract_config(source)
def test_config_not_dict_raises(self):
source = '__config__ = [1, 2, 3]\n'
with pytest.raises(ValueError, match='must be a dict'):
extract_config(source)
def test_config_with_surrounding_code(self):
source = '''
import os
__config__ = {"name": "Mixed", "triggers": {"cron": {"schedule": "0 0 * * *"}}}
def main():
pass
'''
cfg = extract_config(source)
assert cfg['name'] == 'Mixed'
# ---------------------------------------------------------------------------
# validate_config
# ---------------------------------------------------------------------------
class TestValidateConfig:
def test_valid_cron_config(self):
cfg = {
'name': 'My Automation',
'triggers': {
'cron': {'schedule': '0 9 * * 1', 'timezone': 'America/New_York'},
},
}
model = validate_config(cfg)
assert model.name == 'My Automation'
assert model.triggers.cron is not None
assert model.triggers.cron.schedule == '0 9 * * 1'
assert model.triggers.cron.timezone == 'America/New_York'
def test_valid_cron_default_timezone(self):
cfg = {
'name': 'Simple',
'triggers': {'cron': {'schedule': '*/5 * * * *'}},
}
model = validate_config(cfg)
assert model.triggers.cron is not None
assert model.triggers.cron.timezone == 'UTC'
def test_valid_with_description(self):
cfg = {
'name': 'Described',
'triggers': {'cron': {'schedule': '0 0 * * *'}},
'description': 'A helpful description',
}
model = validate_config(cfg)
assert model.description == 'A helpful description'
def test_missing_name_raises(self):
cfg = {
'triggers': {'cron': {'schedule': '0 0 * * *'}},
}
with pytest.raises(ValidationError):
validate_config(cfg)
def test_empty_name_raises(self):
cfg = {
'name': '',
'triggers': {'cron': {'schedule': '0 0 * * *'}},
}
with pytest.raises(ValidationError):
validate_config(cfg)
def test_name_too_long_raises(self):
cfg = {
'name': 'x' * 201,
'triggers': {'cron': {'schedule': '0 0 * * *'}},
}
with pytest.raises(ValidationError):
validate_config(cfg)
def test_missing_triggers_raises(self):
cfg = {'name': 'No Triggers'}
with pytest.raises(ValidationError):
validate_config(cfg)
def test_empty_triggers_raises(self):
cfg = {'name': 'Empty', 'triggers': {}}
with pytest.raises(ValidationError, match='Exactly one trigger'):
validate_config(cfg)
def test_invalid_cron_expression_raises(self):
cfg = {
'name': 'Bad Cron',
'triggers': {'cron': {'schedule': 'not-a-cron'}},
}
with pytest.raises(ValidationError, match='Invalid cron expression'):
validate_config(cfg)
def test_invalid_cron_too_few_fields(self):
cfg = {
'name': 'Short Cron',
'triggers': {'cron': {'schedule': '* *'}},
}
with pytest.raises(ValidationError, match='Invalid cron expression'):
validate_config(cfg)
def test_name_at_boundary_200(self):
cfg = {
'name': 'x' * 200,
'triggers': {'cron': {'schedule': '0 0 * * *'}},
}
model = validate_config(cfg)
assert len(model.name) == 200

View File

@@ -0,0 +1,70 @@
"""Tests for automation event publisher."""
from __future__ import annotations
from unittest.mock import MagicMock
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from services.automation_event_publisher import pg_notify_new_event, publish_automation_event
from storage.automation_event import AutomationEvent
from storage.base import Base
def _make_engine():
engine = create_engine('sqlite://', connect_args={'check_same_thread': False})
Base.metadata.create_all(engine)
return engine
class TestPublishAutomationEvent:
def test_creates_event_with_correct_fields(self):
engine = _make_engine()
Session = sessionmaker(bind=engine)
with Session() as session:
event = publish_automation_event(
session=session,
source_type='cron',
payload={'automation_id': 'abc'},
dedup_key='cron-abc-2025',
metadata={'extra': 'data'},
)
session.commit()
fetched = session.get(AutomationEvent, event.id)
assert fetched is not None
assert fetched.source_type == 'cron'
assert fetched.payload == {'automation_id': 'abc'}
assert fetched.dedup_key == 'cron-abc-2025'
assert fetched.metadata_ == {'extra': 'data'}
assert fetched.status == 'NEW'
def test_creates_event_without_metadata(self):
engine = _make_engine()
Session = sessionmaker(bind=engine)
with Session() as session:
event = publish_automation_event(
session=session,
source_type='manual',
payload={'test': True},
dedup_key='manual-123',
)
session.commit()
fetched = session.get(AutomationEvent, event.id)
assert fetched is not None
assert fetched.metadata_ is None
class TestPgNotifyNewEvent:
def test_pg_notify_executes_sql(self):
"""pg_notify uses PostgreSQL-specific function; verify it at least
constructs the correct SQL statement. On SQLite this will fail at
execution, so we just verify the function doesn't error before execute."""
mock_session = MagicMock()
pg_notify_new_event(mock_session, 42)
mock_session.execute.assert_called_once()
call_args = mock_session.execute.call_args
sql_text = str(call_args[0][0])
assert 'pg_notify' in sql_text

View File

@@ -1,624 +0,0 @@
"""Tests for the automation executor.
Uses real SQLite database operations for event processing, run claiming,
and stale run recovery. HTTP calls to the V1 API are mocked.
"""
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
from services.automation_executor import (
_mark_run_failed,
claim_and_execute_runs,
find_matching_automations,
is_terminal,
process_new_events,
recover_stale_runs,
utc_now,
)
from sqlalchemy import select
from storage.automation import Automation, AutomationRun
from storage.automation_event import AutomationEvent
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_automation(
automation_id: str = 'auto-1',
user_id: str = 'user-1',
enabled: bool = True,
trigger_type: str = 'cron',
name: str = 'Test Automation',
) -> Automation:
return Automation(
id=automation_id,
user_id=user_id,
org_id='org-1',
name=name,
enabled=enabled,
config={'triggers': {'cron': {'schedule': '0 9 * * 5'}}},
trigger_type=trigger_type,
file_store_key=f'automations/{automation_id}/script.py',
)
def make_event(
source_type: str = 'cron',
payload: dict | None = None,
status: str = 'NEW',
dedup_key: str | None = None,
) -> AutomationEvent:
return AutomationEvent(
source_type=source_type,
payload=payload or {'automation_id': 'auto-1'},
dedup_key=dedup_key or f'dedup-{uuid4().hex[:8]}',
status=status,
created_at=utc_now(),
)
def make_run(
run_id: str | None = None,
automation_id: str = 'auto-1',
status: str = 'PENDING',
claimed_by: str | None = None,
heartbeat_at: datetime | None = None,
retry_count: int = 0,
max_retries: int = 3,
next_retry_at: datetime | None = None,
) -> AutomationRun:
return AutomationRun(
id=run_id or uuid4().hex,
automation_id=automation_id,
status=status,
claimed_by=claimed_by,
heartbeat_at=heartbeat_at,
retry_count=retry_count,
max_retries=max_retries,
next_retry_at=next_retry_at,
event_payload={'automation_id': automation_id},
created_at=utc_now(),
)
# ---------------------------------------------------------------------------
# find_matching_automations
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_find_matching_automations_cron_event(async_session):
"""Cron events match by automation_id in payload."""
automation = make_automation()
async_session.add(automation)
await async_session.commit()
event = make_event(source_type='cron', payload={'automation_id': 'auto-1'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 1
assert result[0].id == 'auto-1'
@pytest.mark.asyncio
async def test_find_matching_automations_manual_event(async_session):
"""Manual events also match by automation_id in payload."""
automation = make_automation()
async_session.add(automation)
await async_session.commit()
event = make_event(source_type='manual', payload={'automation_id': 'auto-1'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 1
assert result[0].id == 'auto-1'
@pytest.mark.asyncio
async def test_find_matching_automations_disabled_automation(async_session):
"""Disabled automations are not matched."""
automation = make_automation(enabled=False)
async_session.add(automation)
await async_session.commit()
event = make_event(payload={'automation_id': 'auto-1'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 0
@pytest.mark.asyncio
async def test_find_matching_automations_missing_automation_id(async_session):
"""Events without automation_id in payload return empty list."""
event = make_event(payload={'something_else': 'value'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 0
@pytest.mark.asyncio
async def test_find_matching_automations_nonexistent_automation(async_session):
"""Events referencing a non-existent automation return empty list."""
event = make_event(payload={'automation_id': 'nonexistent'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 0
@pytest.mark.asyncio
async def test_find_matching_automations_unknown_source_type(async_session):
"""Unknown source types return empty list."""
event = make_event(source_type='unknown', payload={'automation_id': 'auto-1'})
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert len(result) == 0
# ---------------------------------------------------------------------------
# process_new_events
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_process_new_events_creates_runs(async_session):
"""Processing NEW events creates PENDING runs and marks events PROCESSED."""
automation = make_automation()
event = make_event(payload={'automation_id': 'auto-1'})
async_session.add_all([automation, event])
await async_session.commit()
count = await process_new_events(async_session)
assert count == 1
# Event should be PROCESSED
await async_session.refresh(event)
assert event.status == 'PROCESSED'
assert event.processed_at is not None
# A run should have been created
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 1
assert runs[0].automation_id == 'auto-1'
assert runs[0].status == 'PENDING'
assert runs[0].event_payload == {'automation_id': 'auto-1'}
@pytest.mark.asyncio
async def test_process_new_events_no_match(async_session):
"""Events with no matching automation are marked NO_MATCH."""
event = make_event(payload={'automation_id': 'nonexistent'})
async_session.add(event)
await async_session.commit()
count = await process_new_events(async_session)
assert count == 1
await async_session.refresh(event)
assert event.status == 'NO_MATCH'
assert event.processed_at is not None
# No runs created
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 0
@pytest.mark.asyncio
async def test_process_new_events_skips_processed(async_session):
"""Already processed events are not re-processed."""
event = make_event(status='PROCESSED')
async_session.add(event)
await async_session.commit()
count = await process_new_events(async_session)
assert count == 0
@pytest.mark.asyncio
async def test_process_new_events_multiple_events(async_session):
"""Multiple NEW events are processed in one batch."""
auto1 = make_automation(automation_id='auto-1')
auto2 = make_automation(automation_id='auto-2', name='Auto 2')
event1 = make_event(payload={'automation_id': 'auto-1'}, dedup_key='dedup-1')
event2 = make_event(payload={'automation_id': 'auto-2'}, dedup_key='dedup-2')
event3 = make_event(payload={'automation_id': 'nonexistent'}, dedup_key='dedup-3')
async_session.add_all([auto1, auto2, event1, event2, event3])
await async_session.commit()
count = await process_new_events(async_session)
assert count == 3
# Two runs created (for auto-1 and auto-2), none for nonexistent
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 2
await async_session.refresh(event1)
await async_session.refresh(event2)
await async_session.refresh(event3)
assert event1.status == 'PROCESSED'
assert event2.status == 'PROCESSED'
assert event3.status == 'NO_MATCH'
# ---------------------------------------------------------------------------
# claim_and_execute_runs
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_claim_and_execute_runs_claims_pending(
async_session, async_session_factory
):
"""Claims a PENDING run and transitions to RUNNING."""
automation = make_automation()
run = make_run(run_id='run-1')
async_session.add_all([automation, run])
await async_session.commit()
api_client = AsyncMock()
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is True
await async_session.refresh(run)
assert run.status == 'RUNNING'
assert run.claimed_by == 'executor-test-1'
assert run.claimed_at is not None
assert run.heartbeat_at is not None
assert run.started_at is not None
@pytest.mark.asyncio
async def test_claim_and_execute_runs_no_pending(async_session, async_session_factory):
"""Returns False when no PENDING runs exist."""
api_client = AsyncMock()
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is False
@pytest.mark.asyncio
async def test_claim_and_execute_runs_respects_next_retry_at(
async_session, async_session_factory
):
"""Runs with future next_retry_at are not claimed."""
automation = make_automation()
run = make_run(
run_id='run-retry',
next_retry_at=utc_now() + timedelta(hours=1),
)
async_session.add_all([automation, run])
await async_session.commit()
api_client = AsyncMock()
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is False
@pytest.mark.asyncio
async def test_claim_and_execute_runs_past_retry_at(
async_session, async_session_factory
):
"""Runs with past next_retry_at are claimable."""
automation = make_automation()
run = make_run(
run_id='run-retry-past',
next_retry_at=utc_now() - timedelta(minutes=5),
)
async_session.add_all([automation, run])
await async_session.commit()
api_client = AsyncMock()
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is True
@pytest.mark.asyncio
async def test_claim_skips_running_runs(async_session, async_session_factory):
"""RUNNING runs are not claimed."""
automation = make_automation()
run = make_run(run_id='run-running', status='RUNNING', claimed_by='other-executor')
async_session.add_all([automation, run])
await async_session.commit()
api_client = AsyncMock()
claimed = await claim_and_execute_runs(
async_session, 'executor-test-1', api_client, async_session_factory
)
assert claimed is False
# ---------------------------------------------------------------------------
# recover_stale_runs
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_recover_stale_runs_recovers_stale(async_session):
"""RUNNING runs with expired heartbeats are recovered to PENDING."""
automation = make_automation()
stale_run = make_run(
run_id='stale-1',
status='RUNNING',
claimed_by='crashed-executor',
heartbeat_at=utc_now() - timedelta(minutes=10),
retry_count=0,
)
async_session.add_all([automation, stale_run])
await async_session.commit()
count = await recover_stale_runs(async_session)
assert count >= 1
await async_session.refresh(stale_run)
assert stale_run.status == 'PENDING'
assert stale_run.claimed_by is None
assert stale_run.retry_count == 1
assert stale_run.next_retry_at is not None
@pytest.mark.asyncio
async def test_recover_stale_runs_ignores_fresh(async_session):
"""RUNNING runs with recent heartbeats are not recovered."""
automation = make_automation()
fresh_run = make_run(
run_id='fresh-1',
status='RUNNING',
claimed_by='active-executor',
heartbeat_at=utc_now() - timedelta(seconds=30),
)
async_session.add_all([automation, fresh_run])
await async_session.commit()
count = await recover_stale_runs(async_session)
assert count == 0
await async_session.refresh(fresh_run)
assert fresh_run.status == 'RUNNING'
assert fresh_run.claimed_by == 'active-executor'
@pytest.mark.asyncio
async def test_recover_stale_runs_ignores_pending(async_session):
"""PENDING runs are not affected by recovery."""
automation = make_automation()
pending_run = make_run(run_id='pending-1', status='PENDING')
async_session.add_all([automation, pending_run])
await async_session.commit()
count = await recover_stale_runs(async_session)
assert count == 0
await async_session.refresh(pending_run)
assert pending_run.status == 'PENDING'
@pytest.mark.asyncio
async def test_recover_stale_runs_increments_retry_count(async_session):
"""Recovery increments the retry_count."""
automation = make_automation()
stale_run = make_run(
run_id='stale-retry',
status='RUNNING',
claimed_by='old-executor',
heartbeat_at=utc_now() - timedelta(minutes=10),
retry_count=2,
)
async_session.add_all([automation, stale_run])
await async_session.commit()
await recover_stale_runs(async_session)
await async_session.refresh(stale_run)
assert stale_run.retry_count == 3
# ---------------------------------------------------------------------------
# _mark_run_failed (error handling)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mark_run_failed_retries(async_session_factory):
"""Failed runs with retries left return to PENDING."""
async with async_session_factory() as session:
automation = make_automation()
run = make_run(run_id='fail-retry', retry_count=0, max_retries=3)
session.add_all([automation, run])
await session.commit()
async with async_session_factory() as session:
run_obj = await session.get(AutomationRun, 'fail-retry')
await _mark_run_failed(run_obj, 'API error', async_session_factory)
async with async_session_factory() as session:
run_obj = await session.get(AutomationRun, 'fail-retry')
assert run_obj.status == 'PENDING'
assert run_obj.retry_count == 1
assert run_obj.error_detail == 'API error'
assert run_obj.next_retry_at is not None
assert run_obj.claimed_by is None
@pytest.mark.asyncio
async def test_mark_run_failed_dead_letter(async_session_factory):
"""Failed runs that exceed max_retries go to DEAD_LETTER."""
async with async_session_factory() as session:
automation = make_automation()
run = make_run(run_id='fail-dead', retry_count=2, max_retries=3)
session.add_all([automation, run])
await session.commit()
async with async_session_factory() as session:
run_obj = await session.get(AutomationRun, 'fail-dead')
await _mark_run_failed(run_obj, 'Final failure', async_session_factory)
async with async_session_factory() as session:
run_obj = await session.get(AutomationRun, 'fail-dead')
assert run_obj.status == 'DEAD_LETTER'
assert run_obj.retry_count == 3
assert run_obj.error_detail == 'Final failure'
assert run_obj.completed_at is not None
# ---------------------------------------------------------------------------
# is_terminal
# ---------------------------------------------------------------------------
def test_is_terminal_stopped():
assert is_terminal({'status': 'STOPPED'}) is True
def test_is_terminal_error():
assert is_terminal({'status': 'ERROR'}) is True
def test_is_terminal_completed():
assert is_terminal({'status': 'COMPLETED'}) is True
def test_is_terminal_cancelled():
assert is_terminal({'status': 'CANCELLED'}) is True
def test_is_terminal_running():
assert is_terminal({'status': 'RUNNING'}) is False
def test_is_terminal_empty():
assert is_terminal({}) is False
def test_is_terminal_case_insensitive():
assert is_terminal({'status': 'stopped'}) is True
assert is_terminal({'status': 'Completed'}) is True
# ---------------------------------------------------------------------------
# find_matching_automations — None payload
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_find_matching_automations_none_payload(async_session):
"""Events with None payload return empty list (data corruption guard)."""
event = make_event(source_type='cron')
event.payload = None
async_session.add(event)
await async_session.commit()
result = await find_matching_automations(async_session, event)
assert result == []
# ---------------------------------------------------------------------------
# Integration: event → run creation → claim
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_integration_event_to_run_to_claim(
async_session_factory,
):
"""Full flow: create event + automation → process_new_events → claim_and_execute_runs.
Uses a real SQLite database; only the external API client is mocked.
"""
# 1. Seed an automation and a NEW event
async with async_session_factory() as session:
automation = make_automation(automation_id='integ-auto')
event = make_event(
source_type='cron',
payload={'automation_id': 'integ-auto'},
dedup_key='integ-dedup',
)
session.add_all([automation, event])
await session.commit()
event_id = event.id
# 2. Process inbox — should match and create a PENDING run
async with async_session_factory() as session:
processed = await process_new_events(session)
assert processed == 1
# Verify event is PROCESSED and run was created
async with async_session_factory() as session:
evt = await session.get(AutomationEvent, event_id)
assert evt.status == 'PROCESSED'
runs = (await session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 1
run = runs[0]
assert run.automation_id == 'integ-auto'
assert run.status == 'PENDING'
assert run.event_payload == {'automation_id': 'integ-auto'}
# 3. Claim the run — mock execute_run to avoid real API calls
api_client = AsyncMock()
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
async with async_session_factory() as session:
claimed = await claim_and_execute_runs(
session, 'executor-integ', api_client, async_session_factory
)
assert claimed is True
# 4. Verify the run moved to RUNNING with correct executor
async with async_session_factory() as session:
runs = (await session.execute(select(AutomationRun))).scalars().all()
assert len(runs) == 1
run = runs[0]
assert run.status == 'RUNNING'
assert run.claimed_by == 'executor-integ'
assert run.started_at is not None
assert run.heartbeat_at is not None

View File

@@ -0,0 +1,104 @@
"""Tests for automation file generator."""
import ast
from services.automation_config import extract_config, validate_config
from services.automation_file_generator import generate_automation_file
class TestGenerateAutomationFile:
def test_generates_valid_python(self):
source = generate_automation_file(
name='Daily Report',
schedule='0 9 * * 1',
timezone='UTC',
prompt='Generate the daily status report.',
)
# Must parse without error
ast.parse(source)
def test_contains_config(self):
source = generate_automation_file(
name='Test Automation',
schedule='*/5 * * * *',
timezone='America/New_York',
prompt='Do something useful.',
)
cfg = extract_config(source)
assert cfg['name'] == 'Test Automation'
assert cfg['triggers']['cron']['schedule'] == '*/5 * * * *'
assert cfg['triggers']['cron']['timezone'] == 'America/New_York'
def test_round_trip(self):
"""Generate → extract → validate must succeed."""
source = generate_automation_file(
name='Round Trip',
schedule='30 14 * * 0',
timezone='Europe/London',
prompt='Weekly summary please.',
)
cfg = extract_config(source)
model = validate_config(cfg)
assert model.name == 'Round Trip'
assert model.triggers.cron is not None
assert model.triggers.cron.schedule == '30 14 * * 0'
assert model.triggers.cron.timezone == 'Europe/London'
def test_contains_prompt(self):
source = generate_automation_file(
name='Prompt Test',
schedule='0 0 * * *',
timezone='UTC',
prompt='Hello world!',
)
assert 'Hello world!' in source
def test_contains_docstring(self):
source = generate_automation_file(
name='Doc Test',
schedule='0 0 * * *',
timezone='UTC',
prompt='test',
)
assert 'Doc Test' in source
assert 'auto-generated automation' in source
def test_special_characters_in_prompt(self):
source = generate_automation_file(
name='Special Chars',
schedule='0 0 * * *',
timezone='UTC',
prompt='Check the "status" of \\n stuff',
)
# Must still be valid Python
ast.parse(source)
cfg = extract_config(source)
assert cfg['name'] == 'Special Chars'
def test_triple_quotes_in_name(self):
"""Names containing triple quotes must not break the generated file."""
source = generate_automation_file(
name='Test """Demo""" Name',
schedule='0 0 * * *',
timezone='UTC',
prompt='hello',
)
# Must still be valid Python
ast.parse(source)
cfg = extract_config(source)
assert 'Demo' in cfg['name'] # name preserved in config
def test_triple_quotes_in_prompt(self):
"""Prompts containing triple quotes must not break the generated file."""
source = generate_automation_file(
name='Triple Quote Test',
schedule='0 0 * * *',
timezone='UTC',
prompt='Use """triple quotes""" and \'\'\'single triples\'\'\' safely',
)
# Must parse without error
ast.parse(source)
cfg = extract_config(source)
assert cfg['name'] == 'Triple Quote Test'
# The prompt must survive round-trip
assert '"""triple quotes"""' in source or "triple quotes" in source

View File

@@ -1,185 +0,0 @@
"""Tests for OpenHandsAPIClient with mocked HTTP responses."""
import base64
import httpx
import pytest
from services.openhands_api_client import OpenHandsAPIClient
@pytest.fixture
def api_client():
client = OpenHandsAPIClient(base_url='http://test-server:3000')
yield client
# close handled in tests that need it
# ---------------------------------------------------------------------------
# start_conversation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_start_conversation_sends_correct_request(api_client, respx_mock):
"""start_conversation sends properly formatted POST with auth header."""
automation_file = b'print("hello")'
expected_b64 = base64.b64encode(automation_file).decode()
route = respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(
200,
json={
'app_conversation_id': 'conv-123',
'status': 'RUNNING',
},
)
)
result = await api_client.start_conversation(
api_key='sk-oh-test123',
automation_file=automation_file,
title='Test Automation',
event_payload={'automation_id': 'auto-1'},
)
assert route.called
request = route.calls[0].request
assert request.headers['Authorization'] == 'Bearer sk-oh-test123'
import json
body = json.loads(request.content)
assert body['automation_file'] == expected_b64
assert body['trigger'] == 'automation'
assert body['title'] == 'Test Automation'
assert body['event_payload'] == {'automation_id': 'auto-1'}
assert result == {'app_conversation_id': 'conv-123', 'status': 'RUNNING'}
@pytest.mark.asyncio
async def test_start_conversation_without_event_payload(api_client, respx_mock):
"""start_conversation works with event_payload=None."""
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(200, json={'app_conversation_id': 'conv-456'})
)
result = await api_client.start_conversation(
api_key='sk-oh-test',
automation_file=b'code',
title='Test',
event_payload=None,
)
assert result['app_conversation_id'] == 'conv-456'
@pytest.mark.asyncio
async def test_start_conversation_http_error(api_client, respx_mock):
"""start_conversation raises on HTTP errors."""
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(500, json={'error': 'Internal Server Error'})
)
with pytest.raises(httpx.HTTPStatusError) as exc_info:
await api_client.start_conversation(
api_key='sk-oh-test',
automation_file=b'code',
title='Test',
)
assert exc_info.value.response.status_code == 500
@pytest.mark.asyncio
async def test_start_conversation_auth_error(api_client, respx_mock):
"""start_conversation raises on 401 Unauthorized."""
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(401, json={'error': 'Unauthorized'})
)
with pytest.raises(httpx.HTTPStatusError) as exc_info:
await api_client.start_conversation(
api_key='bad-key',
automation_file=b'code',
title='Test',
)
assert exc_info.value.response.status_code == 401
# ---------------------------------------------------------------------------
# get_conversation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_conversation_returns_data(api_client, respx_mock):
"""get_conversation returns the first conversation from the list."""
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(
200,
json=[
{
'conversation_id': 'conv-123',
'status': 'RUNNING',
'title': 'My Automation',
}
],
)
)
result = await api_client.get_conversation('sk-oh-test', 'conv-123')
assert result is not None
assert result['conversation_id'] == 'conv-123'
assert result['status'] == 'RUNNING'
@pytest.mark.asyncio
async def test_get_conversation_returns_none_when_empty(api_client, respx_mock):
"""get_conversation returns None when API returns empty list."""
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(200, json=[])
)
result = await api_client.get_conversation('sk-oh-test', 'nonexistent')
assert result is None
@pytest.mark.asyncio
async def test_get_conversation_sends_auth_header(api_client, respx_mock):
"""get_conversation sends the correct authorization header."""
route = respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(200, json=[])
)
await api_client.get_conversation('sk-oh-mykey', 'conv-1')
assert route.called
request = route.calls[0].request
assert request.headers['Authorization'] == 'Bearer sk-oh-mykey'
@pytest.mark.asyncio
async def test_get_conversation_http_error(api_client, respx_mock):
"""get_conversation raises on HTTP errors."""
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
return_value=httpx.Response(503, text='Service Unavailable')
)
with pytest.raises(httpx.HTTPStatusError):
await api_client.get_conversation('sk-oh-test', 'conv-1')
# ---------------------------------------------------------------------------
# close
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_close(api_client):
"""close() shuts down the HTTP client without errors."""
await api_client.close()

View File

@@ -0,0 +1,184 @@
"""Tests for Automation and AutomationRun SQLAlchemy models."""
import uuid
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from storage.automation import Automation, AutomationRun
from storage.automation_event import AutomationEvent
from storage.base import Base
def _make_engine():
engine = create_engine('sqlite://', connect_args={'check_same_thread': False})
Base.metadata.create_all(engine)
return engine
class TestAutomationModel:
def test_create_automation_with_defaults(self):
engine = _make_engine()
Session = sessionmaker(bind=engine)
with Session() as session:
auto = Automation(
id=uuid.uuid4().hex,
user_id='user-1',
name='My Automation',
config={'name': 'My Automation', 'triggers': {'cron': {'schedule': '0 9 * * 1'}}},
trigger_type='cron',
file_store_key='automations/user-1/abc.py',
)
session.add(auto)
session.commit()
fetched = session.get(Automation, auto.id)
assert fetched is not None
assert fetched.name == 'My Automation'
assert fetched.user_id == 'user-1'
assert fetched.trigger_type == 'cron'
assert fetched.org_id is None
assert fetched.last_triggered_at is None
def test_automation_with_org_id(self):
engine = _make_engine()
Session = sessionmaker(bind=engine)
with Session() as session:
auto = Automation(
id=uuid.uuid4().hex,
user_id='user-1',
org_id='org-123',
name='Org Automation',
config={'name': 'Org Automation'},
trigger_type='cron',
file_store_key='automations/user-1/xyz.py',
)
session.add(auto)
session.commit()
fetched = session.get(Automation, auto.id)
assert fetched is not None
assert fetched.org_id == 'org-123'
class TestAutomationRunModel:
def test_create_run_with_defaults(self):
engine = _make_engine()
Session = sessionmaker(bind=engine)
with Session() as session:
auto = Automation(
id=uuid.uuid4().hex,
user_id='user-1',
name='Test',
config={},
trigger_type='cron',
file_store_key='test.py',
)
session.add(auto)
session.flush()
run = AutomationRun(
id=uuid.uuid4().hex,
automation_id=auto.id,
)
session.add(run)
session.commit()
fetched = session.get(AutomationRun, run.id)
assert fetched is not None
assert fetched.automation_id == auto.id
assert fetched.conversation_id is None
assert fetched.event_id is None
assert fetched.claimed_by is None
assert fetched.error_detail is None
def test_run_relationship(self):
engine = _make_engine()
Session = sessionmaker(bind=engine)
with Session() as session:
auto = Automation(
id=uuid.uuid4().hex,
user_id='user-1',
name='Test',
config={},
trigger_type='cron',
file_store_key='test.py',
)
session.add(auto)
session.flush()
run = AutomationRun(
id=uuid.uuid4().hex,
automation_id=auto.id,
)
session.add(run)
session.commit()
session.refresh(auto)
assert len(auto.runs) == 1
assert auto.runs[0].id == run.id
assert run.automation.id == auto.id
class TestAutomationEventModel:
def test_create_event(self):
engine = _make_engine()
Session = sessionmaker(bind=engine)
with Session() as session:
event = AutomationEvent(
source_type='cron',
payload={'tick': True},
dedup_key='cron-2025-01-01T00:00:00',
)
session.add(event)
session.commit()
fetched = session.get(AutomationEvent, event.id)
assert fetched is not None
assert fetched.source_type == 'cron'
assert fetched.payload == {'tick': True}
assert fetched.dedup_key == 'cron-2025-01-01T00:00:00'
assert fetched.error_detail is None
def test_event_with_metadata(self):
engine = _make_engine()
Session = sessionmaker(bind=engine)
with Session() as session:
event = AutomationEvent(
source_type='github',
payload={'action': 'opened'},
dedup_key='gh-12345',
metadata_={'installation_id': 99},
)
session.add(event)
session.commit()
fetched = session.get(AutomationEvent, event.id)
assert fetched is not None
assert fetched.metadata_ == {'installation_id': 99}
def test_dedup_key_unique(self):
engine = _make_engine()
Session = sessionmaker(bind=engine)
import sqlalchemy
with Session() as session:
e1 = AutomationEvent(
source_type='cron',
payload={},
dedup_key='dup-key',
)
session.add(e1)
session.commit()
with Session() as session:
e2 = AutomationEvent(
source_type='cron',
payload={},
dedup_key='dup-key',
)
session.add(e2)
try:
session.commit()
assert False, 'Expected IntegrityError'
except sqlalchemy.exc.IntegrityError:
session.rollback()

View File

@@ -16,6 +16,7 @@ class ConversationTrigger(Enum):
JIRA_DC = 'jira_dc'
LINEAR = 'linear'
BITBUCKET = 'bitbucket'
AUTOMATION = 'automation'
@dataclass