mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
6 Commits
auto/execu
...
auto/data-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
215e769735 | ||
|
|
03a816333b | ||
|
|
54433c5dae | ||
|
|
9269d045c8 | ||
|
|
29d8990263 | ||
|
|
0c9af8290f |
133
enterprise/migrations/versions/100_create_automation_tables.py
Normal file
133
enterprise/migrations/versions/100_create_automation_tables.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Create automation tables (automations, automation_events, automation_runs)
|
||||
|
||||
Revision ID: 100
|
||||
Revises: 099
|
||||
Create Date: 2025-03-10 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '100'
|
||||
down_revision: Union[str, None] = '099'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# --- automation_events (must come first, referenced by automation_runs) ---
|
||||
op.create_table(
|
||||
'automation_events',
|
||||
sa.Column('id', sa.BigInteger(), sa.Identity(), nullable=False, primary_key=True),
|
||||
sa.Column('source_type', sa.String(), nullable=False),
|
||||
sa.Column('payload', sa.JSON(), nullable=False),
|
||||
sa.Column('metadata', sa.JSON(), nullable=True),
|
||||
sa.Column('dedup_key', sa.String(), nullable=False),
|
||||
sa.Column('status', sa.String(), nullable=False, server_default=sa.text("'NEW'")),
|
||||
sa.Column('error_detail', sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column('processed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('dedup_key', name='uq_automation_events_dedup'),
|
||||
)
|
||||
op.create_index(
|
||||
'ix_automation_events_new',
|
||||
'automation_events',
|
||||
['created_at'],
|
||||
postgresql_where=sa.text("status = 'NEW'"),
|
||||
)
|
||||
|
||||
# --- automations ---
|
||||
op.create_table(
|
||||
'automations',
|
||||
sa.Column('id', sa.String(), nullable=False, primary_key=True),
|
||||
sa.Column('user_id', sa.String(), nullable=False),
|
||||
sa.Column('org_id', sa.String(), nullable=True),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('enabled', sa.Boolean(), nullable=False, server_default=sa.text('TRUE')),
|
||||
sa.Column('config', sa.JSON(), nullable=False),
|
||||
sa.Column('trigger_type', sa.String(), nullable=False),
|
||||
sa.Column('file_store_key', sa.String(), nullable=False),
|
||||
sa.Column('last_triggered_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column(
|
||||
'updated_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
op.create_index('ix_automations_user_id', 'automations', ['user_id'])
|
||||
op.create_index('ix_automations_org_id', 'automations', ['org_id'])
|
||||
op.create_index('ix_automations_enabled_trigger', 'automations', ['enabled', 'trigger_type'])
|
||||
|
||||
# --- automation_runs ---
|
||||
op.create_table(
|
||||
'automation_runs',
|
||||
sa.Column('id', sa.String(), nullable=False, primary_key=True),
|
||||
sa.Column(
|
||||
'automation_id',
|
||||
sa.String(),
|
||||
sa.ForeignKey('automations.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
'event_id',
|
||||
sa.BigInteger(),
|
||||
sa.ForeignKey('automation_events.id'),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column('conversation_id', sa.String(), nullable=True),
|
||||
sa.Column('status', sa.String(), nullable=False, server_default=sa.text("'PENDING'")),
|
||||
sa.Column('claimed_by', sa.String(), nullable=True),
|
||||
sa.Column('claimed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('heartbeat_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('retry_count', sa.Integer(), nullable=False, server_default=sa.text('0')),
|
||||
sa.Column('max_retries', sa.Integer(), nullable=False, server_default=sa.text('3')),
|
||||
sa.Column('next_retry_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('event_payload', sa.JSON(), nullable=True),
|
||||
sa.Column('error_detail', sa.Text(), nullable=True),
|
||||
sa.Column('started_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
op.create_index('ix_automation_runs_automation_id', 'automation_runs', ['automation_id'])
|
||||
op.create_index(
|
||||
'ix_automation_runs_claimable',
|
||||
'automation_runs',
|
||||
['status', 'next_retry_at'],
|
||||
postgresql_where=sa.text("status = 'PENDING' AND (next_retry_at IS NULL OR next_retry_at <= now())"),
|
||||
)
|
||||
op.create_index(
|
||||
'ix_automation_runs_heartbeat',
|
||||
'automation_runs',
|
||||
['heartbeat_at'],
|
||||
postgresql_where=sa.text("status = 'RUNNING'"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('automation_runs')
|
||||
op.drop_table('automations')
|
||||
op.drop_table('automation_events')
|
||||
36
enterprise/poetry.lock
generated
36
enterprise/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.3.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "agent-client-protocol"
|
||||
@@ -1641,6 +1641,22 @@ files = [
|
||||
{file = "crashtest-0.4.1.tar.gz", hash = "sha256:80d7b1f316ebfbd429f648076d6275c877ba30ba48979de4191714a75266f0ce"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "croniter"
|
||||
version = "6.0.0"
|
||||
description = "croniter provides iteration for datetime object with cron like format"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.6"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "croniter-6.0.0-py2.py3-none-any.whl", hash = "sha256:2f878c3856f17896979b2a4379ba1f09c83e374931ea15cc835c5dd2eee9b368"},
|
||||
{file = "croniter-6.0.0.tar.gz", hash = "sha256:37c504b313956114a983ece2c2b07790b1f1094fe9d81cc94739214748255577"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
python-dateutil = "*"
|
||||
pytz = ">2021.1"
|
||||
|
||||
[[package]]
|
||||
name = "cryptography"
|
||||
version = "46.0.5"
|
||||
@@ -3501,7 +3517,7 @@ files = [
|
||||
[package.dependencies]
|
||||
googleapis-common-protos = ">=1.5.5"
|
||||
grpcio = ">=1.71.2"
|
||||
protobuf = ">=5.26.1,<6.0dev"
|
||||
protobuf = ">=5.26.1,<6.0.dev0"
|
||||
|
||||
[[package]]
|
||||
name = "gspread"
|
||||
@@ -3819,7 +3835,7 @@ pfzy = ">=0.3.1,<0.4.0"
|
||||
prompt-toolkit = ">=3.0.1,<4.0.0"
|
||||
|
||||
[package.extras]
|
||||
docs = ["Sphinx (>=4.1.2,<5.0.0)", "furo (>=2021.8.17-beta.43,<2022.0.0)", "myst-parser (>=0.15.1,<0.16.0)", "sphinx-autobuild (>=2021.3.14,<2022.0.0)", "sphinx-copybutton (>=0.4.0,<0.5.0)"]
|
||||
docs = ["Sphinx (>=4.1.2,<5.0.0)", "furo (>=2021.8.17b43,<2022.0.0)", "myst-parser (>=0.15.1,<0.16.0)", "sphinx-autobuild (>=2021.3.14,<2022.0.0)", "sphinx-copybutton (>=0.4.0,<0.5.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "installer"
|
||||
@@ -4258,7 +4274,7 @@ fqdn = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
|
||||
idna = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
|
||||
isoduration = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
|
||||
jsonpointer = {version = ">1.13", optional = true, markers = "extra == \"format-nongpl\""}
|
||||
jsonschema-specifications = ">=2023.03.6"
|
||||
jsonschema-specifications = ">=2023.3.6"
|
||||
referencing = ">=0.28.4"
|
||||
rfc3339-validator = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
|
||||
rfc3986-validator = {version = ">0.1.0", optional = true, markers = "extra == \"format-nongpl\""}
|
||||
@@ -4648,7 +4664,7 @@ files = [
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
certifi = ">=14.05.14"
|
||||
certifi = ">=14.5.14"
|
||||
durationpy = ">=0.7"
|
||||
google-auth = ">=1.0.1"
|
||||
oauthlib = ">=3.2.2"
|
||||
@@ -6889,7 +6905,7 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["Sphinx (>=4.1.2,<5.0.0)", "furo (>=2021.8.17-beta.43,<2022.0.0)", "myst-parser (>=0.15.1,<0.16.0)", "sphinx-autobuild (>=2021.3.14,<2022.0.0)", "sphinx-copybutton (>=0.4.0,<0.5.0)"]
|
||||
docs = ["Sphinx (>=4.1.2,<5.0.0)", "furo (>=2021.8.17b43,<2022.0.0)", "myst-parser (>=0.15.1,<0.16.0)", "sphinx-autobuild (>=2021.3.14,<2022.0.0)", "sphinx-copybutton (>=0.4.0,<0.5.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "pg8000"
|
||||
@@ -12866,10 +12882,10 @@ files = [
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
botocore = ">=1.37.4,<2.0a.0"
|
||||
botocore = ">=1.37.4,<2.0a0"
|
||||
|
||||
[package.extras]
|
||||
crt = ["botocore[crt] (>=1.37.4,<2.0a.0)"]
|
||||
crt = ["botocore[crt] (>=1.37.4,<2.0a0)"]
|
||||
|
||||
[[package]]
|
||||
name = "scantree"
|
||||
@@ -15006,9 +15022,9 @@ files = [
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and python_version < \"3.14\"", "cffi (>=2.0.0b) ; platform_python_implementation != \"PyPy\" and python_version >= \"3.14\""]
|
||||
cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and python_version < \"3.14\"", "cffi (>=2.0.0b0) ; platform_python_implementation != \"PyPy\" and python_version >= \"3.14\""]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12,<3.14"
|
||||
content-hash = "ef037f6d6085d26166d35c56ce266439f8f1a4fea90bc43ccf15cfeaf116cae5"
|
||||
content-hash = "d07fdf5fbc8eaf4ed30c119b0c05081d0eb3df60732e530002db5eec84ef080b"
|
||||
|
||||
@@ -17,6 +17,7 @@ packages = [
|
||||
{ include = "storage" },
|
||||
{ include = "sync" },
|
||||
{ include = "integrations" },
|
||||
{ include = "services" },
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
@@ -49,6 +50,7 @@ pandas = "^2.2.0"
|
||||
numpy = "^2.2.0"
|
||||
mcp = "^1.10.0"
|
||||
pillow = "^12.1.1"
|
||||
croniter = "^6.0.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.8.3"
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
"""Entry point for the automation executor.
|
||||
|
||||
Usage: python -m run_automation_executor
|
||||
|
||||
This runs as a Kubernetes Deployment (long-running). It polls the automation_events
|
||||
inbox, matches events to automations, claims and executes runs, and monitors
|
||||
conversation completion.
|
||||
|
||||
Environment variables:
|
||||
OPENHANDS_API_URL Base URL for the V1 API (default: http://openhands-service:3000)
|
||||
MAX_CONCURRENT_RUNS Max concurrent runs per executor (default: 5)
|
||||
RUN_TIMEOUT_SECONDS Max time for a single run (default: 7200)
|
||||
POLL_INTERVAL_SECONDS Fallback poll interval (default: 30)
|
||||
HEARTBEAT_INTERVAL_SECONDS Heartbeat update interval (default: 60)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger('saas.automation.executor')
|
||||
|
||||
|
||||
def _setup_logging() -> None:
|
||||
"""Configure logging, deferring to enterprise logger if available."""
|
||||
try:
|
||||
from server.logger import setup_all_loggers
|
||||
|
||||
setup_all_loggers()
|
||||
except ImportError:
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s %(name)s %(levelname)s %(message)s',
|
||||
stream=sys.stdout,
|
||||
)
|
||||
|
||||
|
||||
def _install_signal_handlers(loop: asyncio.AbstractEventLoop) -> None:
|
||||
"""Install signal handlers for graceful shutdown."""
|
||||
from services.automation_executor import request_shutdown
|
||||
|
||||
def _handle_signal(signum: int, _frame: object) -> None:
|
||||
sig_name = signal.Signals(signum).name
|
||||
logger.info('Received %s, initiating graceful shutdown...', sig_name)
|
||||
request_shutdown()
|
||||
|
||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
signal.signal(sig, _handle_signal)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
from services.automation_executor import executor_main
|
||||
|
||||
await executor_main()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_setup_logging()
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
_install_signal_handlers(loop)
|
||||
|
||||
logger.info('Starting automation executor')
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info('Interrupted by user')
|
||||
finally:
|
||||
loop.close()
|
||||
logger.info('Automation executor process exiting')
|
||||
121
enterprise/services/automation_config.py
Normal file
121
enterprise/services/automation_config.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Automation config extraction and validation.
|
||||
|
||||
Parses ``__config__`` from automation Python source files
|
||||
and validates it against a Pydantic schema.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from croniter import croniter
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
|
||||
|
||||
def extract_config(source: str) -> dict:
|
||||
"""Extract the ``__config__`` dict from automation Python source code.
|
||||
|
||||
Uses :mod:`ast` to safely parse the source and locate a module-level
|
||||
assignment to ``__config__``. The value must be a literal expression
|
||||
(evaluated via :func:`ast.literal_eval`).
|
||||
|
||||
Raises:
|
||||
ValueError: If ``__config__`` is not found or its value contains
|
||||
non-literal expressions.
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(source)
|
||||
except SyntaxError as exc:
|
||||
raise ValueError(f'Failed to parse source: {exc}') from exc
|
||||
|
||||
for node in ast.iter_child_nodes(tree):
|
||||
if isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and target.id == '__config__':
|
||||
try:
|
||||
value = ast.literal_eval(node.value)
|
||||
except (ValueError, TypeError) as exc:
|
||||
raise ValueError(
|
||||
f'__config__ value must be a literal expression: {exc}'
|
||||
) from exc
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError('__config__ must be a dict')
|
||||
return value
|
||||
# Handle annotated assignment: __config__: dict = {...}
|
||||
if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
|
||||
if node.target.id == '__config__' and node.value is not None:
|
||||
try:
|
||||
value = ast.literal_eval(node.value)
|
||||
except (ValueError, TypeError) as exc:
|
||||
raise ValueError(
|
||||
f'__config__ value must be a literal expression: {exc}'
|
||||
) from exc
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError('__config__ must be a dict')
|
||||
return value
|
||||
|
||||
raise ValueError('__config__ not found in source')
|
||||
|
||||
|
||||
class CronTriggerModel(BaseModel):
|
||||
"""Cron trigger configuration."""
|
||||
|
||||
schedule: str
|
||||
timezone: str = 'UTC'
|
||||
|
||||
@field_validator('schedule')
|
||||
@classmethod
|
||||
def validate_schedule(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
if not croniter.is_valid(v):
|
||||
raise ValueError(f'Invalid cron expression: {v!r}')
|
||||
return v
|
||||
|
||||
@field_validator('timezone')
|
||||
@classmethod
|
||||
def validate_timezone(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
try:
|
||||
ZoneInfo(v)
|
||||
except (ZoneInfoNotFoundError, KeyError):
|
||||
raise ValueError(f'Invalid timezone: {v!r}')
|
||||
return v
|
||||
|
||||
|
||||
class TriggersModel(BaseModel):
|
||||
"""Container for trigger definitions. Exactly one trigger must be set."""
|
||||
|
||||
cron: CronTriggerModel | None = None
|
||||
|
||||
@model_validator(mode='after')
|
||||
def exactly_one_trigger(self) -> TriggersModel:
|
||||
defined = [name for name in ('cron',) if getattr(self, name) is not None]
|
||||
if len(defined) != 1:
|
||||
raise ValueError(f'Exactly one trigger must be defined, got: {defined}')
|
||||
return self
|
||||
|
||||
|
||||
class AutomationConfigModel(BaseModel):
|
||||
"""Top-level automation config schema."""
|
||||
|
||||
name: str
|
||||
triggers: TriggersModel
|
||||
description: str = ''
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
v = v.strip()
|
||||
if not (1 <= len(v) <= 200):
|
||||
raise ValueError('name must be between 1 and 200 characters')
|
||||
return v
|
||||
|
||||
|
||||
def validate_config(config: dict) -> AutomationConfigModel:
|
||||
"""Validate a ``__config__`` dict against the automation schema.
|
||||
|
||||
Returns the parsed :class:`AutomationConfigModel` or raises
|
||||
:class:`pydantic.ValidationError`.
|
||||
"""
|
||||
return AutomationConfigModel.model_validate(config)
|
||||
36
enterprise/services/automation_event_publisher.py
Normal file
36
enterprise/services/automation_event_publisher.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Publish automation events and notify listeners via PostgreSQL NOTIFY."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.automation_event import AutomationEvent
|
||||
|
||||
|
||||
def publish_automation_event(
|
||||
session: Session,
|
||||
source_type: str,
|
||||
payload: dict,
|
||||
dedup_key: str,
|
||||
metadata: dict | None = None,
|
||||
) -> AutomationEvent:
|
||||
"""Create an :class:`AutomationEvent` and add it to the session.
|
||||
|
||||
The caller is responsible for committing (or flushing) the session.
|
||||
"""
|
||||
event = AutomationEvent(
|
||||
source_type=source_type,
|
||||
payload=payload,
|
||||
dedup_key=dedup_key,
|
||||
metadata_=metadata,
|
||||
)
|
||||
session.add(event)
|
||||
return event
|
||||
|
||||
|
||||
def pg_notify_new_event(session: Session, event_id: int) -> None:
|
||||
"""Send a PostgreSQL ``NOTIFY`` on the ``automation_events`` channel."""
|
||||
session.execute(
|
||||
text("SELECT pg_notify('automation_events', :event_id)"),
|
||||
{'event_id': str(event_id)},
|
||||
)
|
||||
@@ -1,555 +0,0 @@
|
||||
"""Automation executor — processes events, claims and executes runs.
|
||||
|
||||
The executor is a long-running process with three phases:
|
||||
1. Process inbox: match NEW events to automations, create PENDING runs
|
||||
2. Claim and execute: claim PENDING runs, submit to V1 API, heartbeat
|
||||
3. Stale recovery: recover RUNNING runs with expired heartbeats
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from services.openhands_api_client import OpenHandsAPIClient
|
||||
from sqlalchemy import or_, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.automation import Automation, AutomationRun
|
||||
from storage.automation_event import AutomationEvent
|
||||
|
||||
logger = logging.getLogger('saas.automation.executor')
|
||||
|
||||
# Environment-configurable settings
|
||||
POLL_INTERVAL_SECONDS = float(os.getenv('POLL_INTERVAL_SECONDS', '30'))
|
||||
HEARTBEAT_INTERVAL_SECONDS = float(os.getenv('HEARTBEAT_INTERVAL_SECONDS', '60'))
|
||||
RUN_TIMEOUT_SECONDS = float(os.getenv('RUN_TIMEOUT_SECONDS', '7200'))
|
||||
MAX_CONCURRENT_RUNS = int(os.getenv('MAX_CONCURRENT_RUNS', '5'))
|
||||
STALE_THRESHOLD_MINUTES = 5
|
||||
MAX_EVENTS_PER_BATCH = 50
|
||||
MAX_RETRIES_DEFAULT = 3
|
||||
|
||||
# Terminal conversation statuses
|
||||
TERMINAL_STATUSES = frozenset({'STOPPED', 'ERROR', 'COMPLETED', 'CANCELLED'})
|
||||
|
||||
# Shutdown flag — set by signal handlers
|
||||
_shutdown_event: asyncio.Event | None = None
|
||||
|
||||
# Background task tracking for graceful shutdown
|
||||
_pending_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def get_shutdown_event() -> asyncio.Event:
|
||||
global _shutdown_event
|
||||
if _shutdown_event is None:
|
||||
_shutdown_event = asyncio.Event()
|
||||
return _shutdown_event
|
||||
|
||||
|
||||
def should_continue() -> bool:
|
||||
return not get_shutdown_event().is_set()
|
||||
|
||||
|
||||
def request_shutdown() -> None:
|
||||
get_shutdown_event().set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 1: Process inbox (event matching)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def find_matching_automations(
|
||||
session: AsyncSession, event: AutomationEvent
|
||||
) -> list[Automation]:
|
||||
"""Find automations that match the given event.
|
||||
|
||||
Phase 1 supports cron and manual triggers only — both carry
|
||||
``automation_id`` in the event payload.
|
||||
"""
|
||||
source_type = event.source_type
|
||||
payload = event.payload
|
||||
if payload is None:
|
||||
logger.error('Event %s has None payload — possible data corruption', event.id)
|
||||
return []
|
||||
|
||||
if source_type in ('cron', 'manual'):
|
||||
automation_id = payload.get('automation_id')
|
||||
if not automation_id:
|
||||
logger.warning(
|
||||
'Event %s (source=%s) missing automation_id in payload',
|
||||
event.id,
|
||||
source_type,
|
||||
)
|
||||
return []
|
||||
|
||||
result = await session.execute(
|
||||
select(Automation).where(
|
||||
Automation.id == automation_id,
|
||||
Automation.enabled.is_(True),
|
||||
)
|
||||
)
|
||||
automation = result.scalar_one_or_none()
|
||||
return [automation] if automation else []
|
||||
|
||||
logger.debug('Unhandled event source_type=%s for event %s', source_type, event.id)
|
||||
return []
|
||||
|
||||
|
||||
async def process_new_events(session: AsyncSession) -> int:
|
||||
"""Claim NEW events from inbox, match to automations, create runs.
|
||||
|
||||
Returns the number of events processed.
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(AutomationEvent)
|
||||
.where(AutomationEvent.status == 'NEW')
|
||||
.order_by(AutomationEvent.created_at)
|
||||
.limit(MAX_EVENTS_PER_BATCH)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
events = list(result.scalars())
|
||||
|
||||
processed = 0
|
||||
for event in events:
|
||||
try:
|
||||
automations = await find_matching_automations(session, event)
|
||||
if not automations:
|
||||
event.status = 'NO_MATCH'
|
||||
event.processed_at = utc_now()
|
||||
else:
|
||||
for automation in automations:
|
||||
run = AutomationRun(
|
||||
id=uuid4().hex,
|
||||
automation_id=automation.id,
|
||||
event_id=event.id,
|
||||
status='PENDING',
|
||||
event_payload=event.payload,
|
||||
)
|
||||
session.add(run)
|
||||
event.status = 'PROCESSED'
|
||||
event.processed_at = utc_now()
|
||||
processed += 1
|
||||
except Exception as e:
|
||||
logger.exception('Error processing event %s', event.id)
|
||||
event.status = 'ERROR'
|
||||
event.error_detail = f'Failed during event matching: {type(e).__name__}: {e}'
|
||||
event.processed_at = utc_now()
|
||||
|
||||
if processed:
|
||||
await session.commit()
|
||||
logger.info('Processed %d events', processed)
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 2: Claim and execute runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def resolve_user_api_key(session: AsyncSession, user_id: str) -> str | None:
|
||||
"""Look up a user's API key from the api_keys table.
|
||||
|
||||
Returns the first active key found, or None.
|
||||
"""
|
||||
from storage.api_key import ApiKey
|
||||
|
||||
result = await session.execute(
|
||||
select(ApiKey.key).where(ApiKey.user_id == user_id).limit(1)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
return row
|
||||
|
||||
|
||||
async def download_automation_file(file_store_key: str) -> bytes:
|
||||
"""Download the automation .py file from object storage."""
|
||||
try:
|
||||
from openhands.server.shared import file_store
|
||||
except ImportError as exc:
|
||||
raise RuntimeError(
|
||||
'file_store is not available — ensure the enterprise server '
|
||||
'has been initialised before calling download_automation_file'
|
||||
) from exc
|
||||
|
||||
content = file_store.read(file_store_key)
|
||||
if isinstance(content, str):
|
||||
return content.encode('utf-8')
|
||||
return content
|
||||
|
||||
|
||||
def is_terminal(conversation: dict) -> bool:
|
||||
"""Check if a conversation has reached a terminal status."""
|
||||
status = (conversation.get('status') or '').upper()
|
||||
return status in TERMINAL_STATUSES
|
||||
|
||||
|
||||
async def _prepare_run(
|
||||
run: AutomationRun,
|
||||
automation: Automation,
|
||||
session_factory: object,
|
||||
) -> tuple[str, bytes]:
|
||||
"""Resolve the user's API key and download the automation file.
|
||||
|
||||
Returns:
|
||||
(api_key, automation_file) tuple ready for submission.
|
||||
|
||||
Raises:
|
||||
ValueError: If no API key is found.
|
||||
RuntimeError: If file_store is unavailable.
|
||||
"""
|
||||
async with session_factory() as key_session:
|
||||
api_key = await resolve_user_api_key(key_session, automation.user_id)
|
||||
|
||||
if not api_key:
|
||||
raise ValueError(f'No API key found for user {automation.user_id}')
|
||||
|
||||
automation_file = await download_automation_file(automation.file_store_key)
|
||||
return api_key, automation_file
|
||||
|
||||
|
||||
async def _monitor_conversation(
|
||||
run: AutomationRun,
|
||||
conversation_id: str,
|
||||
api_client: OpenHandsAPIClient,
|
||||
api_key: str,
|
||||
session_factory: object,
|
||||
) -> bool:
|
||||
"""Monitor a conversation until completion or timeout.
|
||||
|
||||
Returns True if completed successfully, False if shutdown requested.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the run exceeds RUN_TIMEOUT_SECONDS.
|
||||
"""
|
||||
start_time = utc_now()
|
||||
while should_continue():
|
||||
elapsed = (utc_now() - start_time).total_seconds()
|
||||
if elapsed > RUN_TIMEOUT_SECONDS:
|
||||
raise TimeoutError(f'Run exceeded {RUN_TIMEOUT_SECONDS}s timeout')
|
||||
|
||||
await asyncio.sleep(HEARTBEAT_INTERVAL_SECONDS)
|
||||
|
||||
# Update heartbeat
|
||||
async with session_factory() as session:
|
||||
run_obj = await session.get(AutomationRun, run.id)
|
||||
if run_obj:
|
||||
run_obj.heartbeat_at = utc_now()
|
||||
await session.commit()
|
||||
|
||||
# Check conversation status
|
||||
conversation = (
|
||||
await api_client.get_conversation(api_key, conversation_id) or {}
|
||||
)
|
||||
if is_terminal(conversation):
|
||||
return True
|
||||
|
||||
return False # shutdown requested
|
||||
|
||||
|
||||
async def _submit_and_monitor(
|
||||
run: AutomationRun,
|
||||
api_key: str,
|
||||
automation_file: bytes,
|
||||
automation: Automation,
|
||||
api_client: OpenHandsAPIClient,
|
||||
session_factory: object,
|
||||
) -> None:
|
||||
"""Submit the automation to the V1 API and monitor until completion.
|
||||
|
||||
Updates the run's conversation_id, sends heartbeats, and marks the
|
||||
final status when the conversation reaches a terminal state.
|
||||
"""
|
||||
conversation = await api_client.start_conversation(
|
||||
api_key=api_key,
|
||||
automation_file=automation_file,
|
||||
title=f'Automation: {automation.name}',
|
||||
event_payload=run.event_payload,
|
||||
)
|
||||
|
||||
conversation_id = conversation.get('app_conversation_id') or conversation.get(
|
||||
'conversation_id'
|
||||
)
|
||||
|
||||
# Persist conversation ID
|
||||
async with session_factory() as update_session:
|
||||
run_obj = await update_session.get(AutomationRun, run.id)
|
||||
if run_obj:
|
||||
run_obj.conversation_id = conversation_id
|
||||
await update_session.commit()
|
||||
|
||||
# Monitor with heartbeats
|
||||
completed = await _monitor_conversation(
|
||||
run, conversation_id, api_client, api_key, session_factory
|
||||
)
|
||||
|
||||
# Update final status
|
||||
async with session_factory() as final_session:
|
||||
run_obj = await final_session.get(AutomationRun, run.id)
|
||||
if run_obj:
|
||||
if not completed:
|
||||
# Leave as RUNNING — stale recovery will handle it if needed.
|
||||
# The conversation may still be running on the API side.
|
||||
logger.info(
|
||||
'Run %s left as RUNNING due to executor shutdown', run.id
|
||||
)
|
||||
else:
|
||||
run_obj.status = 'COMPLETED'
|
||||
run_obj.completed_at = utc_now()
|
||||
logger.info('Run %s completed successfully', run.id)
|
||||
await final_session.commit()
|
||||
|
||||
|
||||
async def execute_run(
|
||||
run: AutomationRun,
|
||||
automation: Automation,
|
||||
api_client: OpenHandsAPIClient,
|
||||
session_factory: object,
|
||||
) -> None:
|
||||
"""Execute a single automation run end-to-end.
|
||||
|
||||
Orchestrates preparation (API key + file download) and submission/monitoring.
|
||||
On failure, marks the run for retry or dead-letter.
|
||||
"""
|
||||
try:
|
||||
api_key, automation_file = await _prepare_run(
|
||||
run, automation, session_factory
|
||||
)
|
||||
await _submit_and_monitor(
|
||||
run, api_key, automation_file, automation, api_client, session_factory
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception('Run %s failed: %s', run.id, e)
|
||||
await _mark_run_failed(run, str(e), session_factory)
|
||||
|
||||
|
||||
async def _mark_run_failed(
|
||||
run: AutomationRun, error: str, session_factory: object
|
||||
) -> None:
|
||||
"""Mark a run as FAILED or return to PENDING for retry."""
|
||||
async with session_factory() as session:
|
||||
run_obj = await session.get(AutomationRun, run.id)
|
||||
if not run_obj:
|
||||
return
|
||||
|
||||
run_obj.retry_count = (run_obj.retry_count or 0) + 1
|
||||
run_obj.error_detail = error
|
||||
|
||||
if run_obj.retry_count >= (run_obj.max_retries or MAX_RETRIES_DEFAULT):
|
||||
run_obj.status = 'DEAD_LETTER'
|
||||
run_obj.completed_at = utc_now()
|
||||
logger.error(
|
||||
'Run %s moved to DEAD_LETTER after %d retries',
|
||||
run.id,
|
||||
run_obj.retry_count,
|
||||
)
|
||||
else:
|
||||
run_obj.status = 'PENDING'
|
||||
run_obj.claimed_by = None
|
||||
backoff_seconds = 30 * (2 ** (run_obj.retry_count - 1))
|
||||
run_obj.next_retry_at = utc_now() + timedelta(seconds=backoff_seconds)
|
||||
logger.warning(
|
||||
'Run %s returned to PENDING, retry %d/%d in %ds',
|
||||
run.id,
|
||||
run_obj.retry_count,
|
||||
run_obj.max_retries or MAX_RETRIES_DEFAULT,
|
||||
backoff_seconds,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def claim_and_execute_runs(
|
||||
session: AsyncSession,
|
||||
executor_id: str,
|
||||
api_client: OpenHandsAPIClient,
|
||||
session_factory: object,
|
||||
) -> bool:
|
||||
"""Claim a PENDING run and start executing it.
|
||||
|
||||
Returns True if a run was claimed, False otherwise.
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(AutomationRun)
|
||||
.where(
|
||||
AutomationRun.status == 'PENDING',
|
||||
or_(
|
||||
AutomationRun.next_retry_at.is_(None),
|
||||
AutomationRun.next_retry_at <= utc_now(),
|
||||
),
|
||||
)
|
||||
.order_by(AutomationRun.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
run = result.scalar_one_or_none()
|
||||
if not run:
|
||||
return False
|
||||
|
||||
# Claim the run
|
||||
run.status = 'RUNNING'
|
||||
run.claimed_by = executor_id
|
||||
run.claimed_at = utc_now()
|
||||
run.heartbeat_at = utc_now()
|
||||
run.started_at = utc_now()
|
||||
await session.commit()
|
||||
|
||||
# Load automation for the run
|
||||
auto_result = await session.execute(
|
||||
select(Automation).where(Automation.id == run.automation_id)
|
||||
)
|
||||
automation = auto_result.scalar_one_or_none()
|
||||
if not automation:
|
||||
logger.error('Automation %s not found for run %s', run.automation_id, run.id)
|
||||
await _mark_run_failed(
|
||||
run, f'Automation {run.automation_id} not found', session_factory
|
||||
)
|
||||
return True
|
||||
|
||||
# Execute in background (long-running) with task tracking
|
||||
task = asyncio.create_task(
|
||||
execute_run(run, automation, api_client, session_factory),
|
||||
name=f'execute-run-{run.id}',
|
||||
)
|
||||
_pending_tasks.add(task)
|
||||
task.add_done_callback(_pending_tasks.discard)
|
||||
|
||||
logger.info(
|
||||
'Claimed run %s (automation=%s) by executor %s',
|
||||
run.id,
|
||||
run.automation_id,
|
||||
executor_id,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 3: Stale run recovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def recover_stale_runs(session: AsyncSession) -> int:
|
||||
"""Mark RUNNING runs with expired heartbeats as PENDING for retry.
|
||||
|
||||
Returns the number of recovered runs.
|
||||
"""
|
||||
stale_threshold = utc_now() - timedelta(minutes=STALE_THRESHOLD_MINUTES)
|
||||
timeout_threshold = utc_now() - timedelta(seconds=RUN_TIMEOUT_SECONDS)
|
||||
|
||||
# Recover stale runs (heartbeat expired)
|
||||
result = await session.execute(
|
||||
update(AutomationRun)
|
||||
.where(
|
||||
AutomationRun.status == 'RUNNING',
|
||||
AutomationRun.heartbeat_at < stale_threshold,
|
||||
AutomationRun.heartbeat_at >= timeout_threshold,
|
||||
)
|
||||
.values(
|
||||
status='PENDING',
|
||||
claimed_by=None,
|
||||
retry_count=AutomationRun.retry_count + 1,
|
||||
next_retry_at=utc_now() + timedelta(seconds=30),
|
||||
)
|
||||
.returning(AutomationRun.id)
|
||||
)
|
||||
recovered_rows = result.fetchall()
|
||||
|
||||
# Mark truly timed-out runs as DEAD_LETTER
|
||||
timeout_result = await session.execute(
|
||||
update(AutomationRun)
|
||||
.where(
|
||||
AutomationRun.status == 'RUNNING',
|
||||
AutomationRun.heartbeat_at < timeout_threshold,
|
||||
)
|
||||
.values(
|
||||
status='DEAD_LETTER',
|
||||
error_detail='Run exceeded timeout',
|
||||
completed_at=utc_now(),
|
||||
)
|
||||
.returning(AutomationRun.id)
|
||||
)
|
||||
timed_out_rows = timeout_result.fetchall()
|
||||
|
||||
await session.commit()
|
||||
|
||||
recovered_count = len(recovered_rows)
|
||||
timed_out_count = len(timed_out_rows)
|
||||
|
||||
if recovered_count:
|
||||
logger.warning('Recovered %d stale automation runs', recovered_count)
|
||||
if timed_out_count:
|
||||
logger.warning(
|
||||
'Marked %d automation runs as DEAD_LETTER (timeout)', timed_out_count
|
||||
)
|
||||
|
||||
return recovered_count + timed_out_count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main executor loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def executor_main(session_factory: object | None = None) -> None:
|
||||
"""Main executor loop.
|
||||
|
||||
Args:
|
||||
session_factory: Async context manager that yields AsyncSession instances.
|
||||
If None, uses the default ``a_session_maker`` from database module.
|
||||
"""
|
||||
if session_factory is None:
|
||||
from storage.database import a_session_maker
|
||||
|
||||
session_factory = a_session_maker
|
||||
|
||||
executor_id = f'executor-{socket.gethostname()}-{os.getpid()}'
|
||||
api_url = os.getenv('OPENHANDS_API_URL', 'http://openhands-service:3000')
|
||||
api_client = OpenHandsAPIClient(base_url=api_url)
|
||||
|
||||
logger.info(
|
||||
'Automation executor %s starting (api_url=%s, poll=%ss, heartbeat=%ss)',
|
||||
executor_id,
|
||||
api_url,
|
||||
POLL_INTERVAL_SECONDS,
|
||||
HEARTBEAT_INTERVAL_SECONDS,
|
||||
)
|
||||
|
||||
try:
|
||||
while should_continue():
|
||||
try:
|
||||
async with session_factory() as session:
|
||||
await process_new_events(session)
|
||||
|
||||
async with session_factory() as session:
|
||||
await claim_and_execute_runs(
|
||||
session, executor_id, api_client, session_factory
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
await recover_stale_runs(session)
|
||||
|
||||
except Exception:
|
||||
logger.exception('Error in executor main loop iteration')
|
||||
|
||||
# Wait for next poll interval (or early wakeup on shutdown)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
get_shutdown_event().wait(),
|
||||
timeout=POLL_INTERVAL_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass # Normal — poll interval elapsed
|
||||
|
||||
finally:
|
||||
if _pending_tasks:
|
||||
logger.info(
|
||||
'Waiting for %d running tasks to complete...', len(_pending_tasks)
|
||||
)
|
||||
await asyncio.gather(*_pending_tasks, return_exceptions=True)
|
||||
await api_client.close()
|
||||
logger.info('Automation executor %s shut down', executor_id)
|
||||
56
enterprise/services/automation_file_generator.py
Normal file
56
enterprise/services/automation_file_generator.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Generate automation Python files from user-provided parameters."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
|
||||
|
||||
def generate_automation_file(
|
||||
name: str,
|
||||
schedule: str,
|
||||
timezone: str,
|
||||
prompt: str,
|
||||
) -> str:
|
||||
"""Return a complete, valid Python file string for an automation.
|
||||
|
||||
The generated file includes a ``__config__`` dict that can be round-tripped
|
||||
through :func:`services.automation_config.extract_config` and
|
||||
:func:`services.automation_config.validate_config`.
|
||||
"""
|
||||
# Use repr() for safe string escaping — handles backslashes, quotes, etc.
|
||||
r_name = repr(name)
|
||||
r_schedule = repr(schedule)
|
||||
r_timezone = repr(timezone)
|
||||
r_prompt = repr(prompt)
|
||||
|
||||
# Build a safe docstring — replace double quotes with single quotes to
|
||||
# prevent triple-quote breakage in the docstring.
|
||||
safe_docstring_name = name.replace('"', "'")
|
||||
|
||||
return textwrap.dedent(f'''\
|
||||
"""{safe_docstring_name} — auto-generated automation."""
|
||||
|
||||
__config__ = {{
|
||||
"name": {r_name},
|
||||
"triggers": {{
|
||||
"cron": {{
|
||||
"schedule": {r_schedule},
|
||||
"timezone": {r_timezone},
|
||||
}}
|
||||
}},
|
||||
}}
|
||||
|
||||
import os
|
||||
from openhands.sdk import LLM, Conversation
|
||||
from openhands.tools.preset.default import get_default_agent
|
||||
|
||||
llm = LLM(
|
||||
model=os.getenv("LLM_MODEL", "anthropic/claude-sonnet-4-5-20250929"),
|
||||
api_key=os.getenv("LLM_API_KEY"),
|
||||
base_url=os.getenv("LLM_BASE_URL"),
|
||||
)
|
||||
agent = get_default_agent(llm=llm, cli_mode=True)
|
||||
conversation = Conversation(agent=agent, workspace=os.getcwd())
|
||||
conversation.send_message({r_prompt})
|
||||
conversation.run()
|
||||
''')
|
||||
@@ -1,93 +0,0 @@
|
||||
"""HTTP client for the main OpenHands V1 API (internal cluster calls).
|
||||
|
||||
Used by the automation executor to create and monitor conversations
|
||||
in the main OpenHands server.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger('saas.automation.api_client')
|
||||
|
||||
|
||||
def _raise_with_body(resp: httpx.Response) -> None:
|
||||
"""Call raise_for_status, enriching the error with the response body."""
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_body = resp.text[:500] if resp.text else 'no response body'
|
||||
raise httpx.HTTPStatusError(
|
||||
f'{e.args[0]} — Response: {error_body}',
|
||||
request=e.request,
|
||||
response=e.response,
|
||||
) from e
|
||||
|
||||
|
||||
class OpenHandsAPIClient:
|
||||
"""Async HTTP client for the OpenHands V1 API."""
|
||||
|
||||
def __init__(self, base_url: str = 'http://openhands-service:3000'):
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.client = httpx.AsyncClient(base_url=self.base_url, timeout=30.0)
|
||||
|
||||
async def start_conversation(
|
||||
self,
|
||||
api_key: str,
|
||||
automation_file: bytes,
|
||||
title: str,
|
||||
event_payload: dict | None = None,
|
||||
) -> dict:
|
||||
"""Submit an SDK script for sandboxed execution via V1 API.
|
||||
|
||||
Args:
|
||||
api_key: User's API key for authentication.
|
||||
automation_file: Raw bytes of the .py automation script.
|
||||
title: Display title for the conversation.
|
||||
event_payload: Optional trigger event data (injected as env var).
|
||||
|
||||
Returns:
|
||||
Parsed JSON response containing conversation details.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the API returns a non-2xx status.
|
||||
"""
|
||||
resp = await self.client.post(
|
||||
'/api/v1/app-conversations',
|
||||
json={
|
||||
'automation_file': base64.b64encode(automation_file).decode(),
|
||||
'trigger': 'automation',
|
||||
'title': title,
|
||||
'event_payload': event_payload,
|
||||
},
|
||||
headers={'Authorization': f'Bearer {api_key}'},
|
||||
)
|
||||
_raise_with_body(resp)
|
||||
return resp.json()
|
||||
|
||||
async def get_conversation(self, api_key: str, conversation_id: str) -> dict | None:
|
||||
"""Get conversation status.
|
||||
|
||||
Args:
|
||||
api_key: User's API key for authentication.
|
||||
conversation_id: The conversation ID to look up.
|
||||
|
||||
Returns:
|
||||
Conversation data dict, or None if not found.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the API returns a non-2xx status.
|
||||
"""
|
||||
resp = await self.client.get(
|
||||
'/api/v1/app-conversations',
|
||||
params={'ids': [conversation_id]},
|
||||
headers={'Authorization': f'Bearer {api_key}'},
|
||||
)
|
||||
_raise_with_body(resp)
|
||||
conversations = resp.json()
|
||||
return conversations[0] if conversations else None
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying HTTP client."""
|
||||
await self.client.aclose()
|
||||
@@ -1,61 +1,78 @@
|
||||
"""SQLAlchemy models for automations and automation runs.
|
||||
|
||||
Stub for Task 1 (Data Foundation). These models will be replaced when Task 1
|
||||
is merged into automations-phase1.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
func,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.types import JSON
|
||||
from storage.base import Base
|
||||
|
||||
# Use JSON with JSONB variant so models work on SQLite (tests) and PostgreSQL (prod)
|
||||
_JsonType = JSON().with_variant(JSONB(), 'postgresql')
|
||||
|
||||
|
||||
class Automation(Base): # type: ignore
|
||||
"""Model for storing automation definitions."""
|
||||
|
||||
class Automation(Base):
|
||||
__tablename__ = 'automations'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False, index=True)
|
||||
org_id = Column(String, nullable=True, index=True)
|
||||
|
||||
user_id = Column(String, nullable=False)
|
||||
org_id = Column(String, nullable=True)
|
||||
name = Column(String, nullable=False)
|
||||
enabled = Column(Boolean, nullable=False, server_default=text('true'))
|
||||
|
||||
config = Column(JSON, nullable=False)
|
||||
enabled = Column(Boolean, nullable=False, server_default=text('TRUE'))
|
||||
config = Column(_JsonType, nullable=False)
|
||||
trigger_type = Column(String, nullable=False)
|
||||
file_store_key = Column(String, nullable=False)
|
||||
|
||||
last_triggered_at = Column(DateTime(timezone=True), nullable=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
runs = relationship('AutomationRun', back_populates='automation')
|
||||
runs = relationship('AutomationRun', back_populates='automation', cascade='all, delete-orphan')
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_automations_user_id', 'user_id'),
|
||||
Index('ix_automations_org_id', 'org_id'),
|
||||
Index('ix_automations_enabled_trigger', 'enabled', 'trigger_type'),
|
||||
)
|
||||
|
||||
|
||||
class AutomationRun(Base):
|
||||
class AutomationRun(Base): # type: ignore
|
||||
"""Model for storing automation run records."""
|
||||
|
||||
__tablename__ = 'automation_runs'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
automation_id = Column(
|
||||
String, ForeignKey('automations.id', ondelete='CASCADE'), nullable=False
|
||||
String,
|
||||
ForeignKey('automations.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
)
|
||||
event_id = Column(
|
||||
BigInteger,
|
||||
ForeignKey('automation_events.id'),
|
||||
nullable=True,
|
||||
)
|
||||
event_id = Column(Integer, ForeignKey('automation_events.id'), nullable=True)
|
||||
conversation_id = Column(String, nullable=True)
|
||||
status = Column(String, nullable=False, server_default=text("'PENDING'"))
|
||||
claimed_by = Column(String, nullable=True)
|
||||
@@ -64,14 +81,29 @@ class AutomationRun(Base):
|
||||
retry_count = Column(Integer, nullable=False, server_default=text('0'))
|
||||
max_retries = Column(Integer, nullable=False, server_default=text('3'))
|
||||
next_retry_at = Column(DateTime(timezone=True), nullable=True)
|
||||
event_payload = Column(JSON, nullable=True)
|
||||
event_payload = Column(_JsonType, nullable=True)
|
||||
error_detail = Column(Text, nullable=True)
|
||||
started_at = Column(DateTime(timezone=True), nullable=True)
|
||||
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
automation = relationship('Automation', back_populates='runs')
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
'ix_automation_runs_claimable',
|
||||
'status',
|
||||
'next_retry_at',
|
||||
postgresql_where=text("status = 'PENDING' AND (next_retry_at IS NULL OR next_retry_at <= now())"),
|
||||
),
|
||||
Index('ix_automation_runs_automation_id', 'automation_id'),
|
||||
Index(
|
||||
'ix_automation_runs_heartbeat',
|
||||
'heartbeat_at',
|
||||
postgresql_where=text("status = 'RUNNING'"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,27 +1,37 @@
|
||||
"""SQLAlchemy model for automation events (the inbox).
|
||||
from __future__ import annotations
|
||||
|
||||
Stub for Task 1 (Data Foundation). This model will be replaced when Task 1
|
||||
is merged into automations-phase1.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Integer, String, Text, text
|
||||
from sqlalchemy import BigInteger, Column, DateTime, Index, String, Text, UniqueConstraint, text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.types import JSON
|
||||
from storage.base import Base
|
||||
|
||||
_JsonType = JSON().with_variant(JSONB(), 'postgresql')
|
||||
|
||||
|
||||
class AutomationEvent(Base): # type: ignore
|
||||
"""Model for storing raw automation trigger events."""
|
||||
|
||||
class AutomationEvent(Base):
|
||||
__tablename__ = 'automation_events'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
id = Column(BigInteger, primary_key=True, autoincrement=True)
|
||||
source_type = Column(String, nullable=False)
|
||||
payload = Column(JSON, nullable=False)
|
||||
metadata_ = Column('metadata', JSON, nullable=True)
|
||||
payload = Column(_JsonType, nullable=False)
|
||||
metadata_ = Column('metadata', _JsonType, nullable=True)
|
||||
dedup_key = Column(String, nullable=False, unique=True)
|
||||
status = Column(String, nullable=False, server_default=text("'NEW'"))
|
||||
error_detail = Column(Text, nullable=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
processed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint('dedup_key', name='uq_automation_events_dedup'),
|
||||
Index(
|
||||
'ix_automation_events_new',
|
||||
'created_at',
|
||||
postgresql_where=text("status = 'NEW'"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
"""Shared fixtures for services tests.
|
||||
|
||||
Note: We pre-load ``storage`` as a namespace package to avoid the heavy
|
||||
``storage/__init__.py`` that imports the entire enterprise model graph.
|
||||
This must happen *before* any ``from storage.…`` import.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Prevent storage/__init__.py from loading the full model graph.
|
||||
# We only need the lightweight automation models for these tests.
|
||||
if 'storage' not in sys.modules:
|
||||
import pathlib
|
||||
|
||||
_storage_dir = str(pathlib.Path(__file__).resolve().parents[3] / 'storage')
|
||||
_mod = types.ModuleType('storage')
|
||||
_mod.__path__ = [_storage_dir]
|
||||
sys.modules['storage'] = _mod
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.automation import Automation, AutomationRun # noqa: F401
|
||||
from storage.automation_event import AutomationEvent # noqa: F401
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_factory(async_engine):
|
||||
"""Create an async session factory that yields context-managed sessions."""
|
||||
factory = async_sessionmaker(
|
||||
bind=async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _session_ctx():
|
||||
async with factory() as session:
|
||||
yield session
|
||||
|
||||
return _session_ctx
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session(async_session_factory):
|
||||
"""Create a single async session for testing."""
|
||||
async with async_session_factory() as session:
|
||||
yield session
|
||||
160
enterprise/tests/unit/services/test_automation_config.py
Normal file
160
enterprise/tests/unit/services/test_automation_config.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Tests for automation config extraction and validation."""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from services.automation_config import extract_config, validate_config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractConfig:
|
||||
def test_plain_dict(self):
|
||||
source = '''
|
||||
__config__ = {
|
||||
"name": "Daily Report",
|
||||
"triggers": {"cron": {"schedule": "0 9 * * 1"}},
|
||||
}
|
||||
'''
|
||||
cfg = extract_config(source)
|
||||
assert cfg['name'] == 'Daily Report'
|
||||
assert cfg['triggers']['cron']['schedule'] == '0 9 * * 1'
|
||||
|
||||
def test_annotated_assignment(self):
|
||||
source = '''
|
||||
__config__: dict = {
|
||||
"name": "Annotated",
|
||||
"triggers": {"cron": {"schedule": "*/5 * * * *"}},
|
||||
"description": "annotated config",
|
||||
}
|
||||
'''
|
||||
cfg = extract_config(source)
|
||||
assert cfg['name'] == 'Annotated'
|
||||
assert cfg['description'] == 'annotated config'
|
||||
|
||||
def test_no_config_raises(self):
|
||||
source = 'x = 1\ny = 2\n'
|
||||
with pytest.raises(ValueError, match='__config__ not found'):
|
||||
extract_config(source)
|
||||
|
||||
def test_non_literal_value_raises(self):
|
||||
source = '__config__ = some_function()\n'
|
||||
with pytest.raises(ValueError, match='literal expression'):
|
||||
extract_config(source)
|
||||
|
||||
def test_bad_syntax_raises(self):
|
||||
source = 'def foo(\n'
|
||||
with pytest.raises(ValueError, match='Failed to parse'):
|
||||
extract_config(source)
|
||||
|
||||
def test_config_not_dict_raises(self):
|
||||
source = '__config__ = [1, 2, 3]\n'
|
||||
with pytest.raises(ValueError, match='must be a dict'):
|
||||
extract_config(source)
|
||||
|
||||
def test_config_with_surrounding_code(self):
|
||||
source = '''
|
||||
import os
|
||||
|
||||
__config__ = {"name": "Mixed", "triggers": {"cron": {"schedule": "0 0 * * *"}}}
|
||||
|
||||
def main():
|
||||
pass
|
||||
'''
|
||||
cfg = extract_config(source)
|
||||
assert cfg['name'] == 'Mixed'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# validate_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidateConfig:
|
||||
def test_valid_cron_config(self):
|
||||
cfg = {
|
||||
'name': 'My Automation',
|
||||
'triggers': {
|
||||
'cron': {'schedule': '0 9 * * 1', 'timezone': 'America/New_York'},
|
||||
},
|
||||
}
|
||||
model = validate_config(cfg)
|
||||
assert model.name == 'My Automation'
|
||||
assert model.triggers.cron is not None
|
||||
assert model.triggers.cron.schedule == '0 9 * * 1'
|
||||
assert model.triggers.cron.timezone == 'America/New_York'
|
||||
|
||||
def test_valid_cron_default_timezone(self):
|
||||
cfg = {
|
||||
'name': 'Simple',
|
||||
'triggers': {'cron': {'schedule': '*/5 * * * *'}},
|
||||
}
|
||||
model = validate_config(cfg)
|
||||
assert model.triggers.cron is not None
|
||||
assert model.triggers.cron.timezone == 'UTC'
|
||||
|
||||
def test_valid_with_description(self):
|
||||
cfg = {
|
||||
'name': 'Described',
|
||||
'triggers': {'cron': {'schedule': '0 0 * * *'}},
|
||||
'description': 'A helpful description',
|
||||
}
|
||||
model = validate_config(cfg)
|
||||
assert model.description == 'A helpful description'
|
||||
|
||||
def test_missing_name_raises(self):
|
||||
cfg = {
|
||||
'triggers': {'cron': {'schedule': '0 0 * * *'}},
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_empty_name_raises(self):
|
||||
cfg = {
|
||||
'name': '',
|
||||
'triggers': {'cron': {'schedule': '0 0 * * *'}},
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_name_too_long_raises(self):
|
||||
cfg = {
|
||||
'name': 'x' * 201,
|
||||
'triggers': {'cron': {'schedule': '0 0 * * *'}},
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_missing_triggers_raises(self):
|
||||
cfg = {'name': 'No Triggers'}
|
||||
with pytest.raises(ValidationError):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_empty_triggers_raises(self):
|
||||
cfg = {'name': 'Empty', 'triggers': {}}
|
||||
with pytest.raises(ValidationError, match='Exactly one trigger'):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_invalid_cron_expression_raises(self):
|
||||
cfg = {
|
||||
'name': 'Bad Cron',
|
||||
'triggers': {'cron': {'schedule': 'not-a-cron'}},
|
||||
}
|
||||
with pytest.raises(ValidationError, match='Invalid cron expression'):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_invalid_cron_too_few_fields(self):
|
||||
cfg = {
|
||||
'name': 'Short Cron',
|
||||
'triggers': {'cron': {'schedule': '* *'}},
|
||||
}
|
||||
with pytest.raises(ValidationError, match='Invalid cron expression'):
|
||||
validate_config(cfg)
|
||||
|
||||
def test_name_at_boundary_200(self):
|
||||
cfg = {
|
||||
'name': 'x' * 200,
|
||||
'triggers': {'cron': {'schedule': '0 0 * * *'}},
|
||||
}
|
||||
model = validate_config(cfg)
|
||||
assert len(model.name) == 200
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Tests for automation event publisher."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from services.automation_event_publisher import pg_notify_new_event, publish_automation_event
|
||||
from storage.automation_event import AutomationEvent
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
def _make_engine():
|
||||
engine = create_engine('sqlite://', connect_args={'check_same_thread': False})
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
class TestPublishAutomationEvent:
|
||||
def test_creates_event_with_correct_fields(self):
|
||||
engine = _make_engine()
|
||||
Session = sessionmaker(bind=engine)
|
||||
with Session() as session:
|
||||
event = publish_automation_event(
|
||||
session=session,
|
||||
source_type='cron',
|
||||
payload={'automation_id': 'abc'},
|
||||
dedup_key='cron-abc-2025',
|
||||
metadata={'extra': 'data'},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
fetched = session.get(AutomationEvent, event.id)
|
||||
assert fetched is not None
|
||||
assert fetched.source_type == 'cron'
|
||||
assert fetched.payload == {'automation_id': 'abc'}
|
||||
assert fetched.dedup_key == 'cron-abc-2025'
|
||||
assert fetched.metadata_ == {'extra': 'data'}
|
||||
assert fetched.status == 'NEW'
|
||||
|
||||
def test_creates_event_without_metadata(self):
|
||||
engine = _make_engine()
|
||||
Session = sessionmaker(bind=engine)
|
||||
with Session() as session:
|
||||
event = publish_automation_event(
|
||||
session=session,
|
||||
source_type='manual',
|
||||
payload={'test': True},
|
||||
dedup_key='manual-123',
|
||||
)
|
||||
session.commit()
|
||||
|
||||
fetched = session.get(AutomationEvent, event.id)
|
||||
assert fetched is not None
|
||||
assert fetched.metadata_ is None
|
||||
|
||||
|
||||
class TestPgNotifyNewEvent:
|
||||
def test_pg_notify_executes_sql(self):
|
||||
"""pg_notify uses PostgreSQL-specific function; verify it at least
|
||||
constructs the correct SQL statement. On SQLite this will fail at
|
||||
execution, so we just verify the function doesn't error before execute."""
|
||||
mock_session = MagicMock()
|
||||
pg_notify_new_event(mock_session, 42)
|
||||
mock_session.execute.assert_called_once()
|
||||
call_args = mock_session.execute.call_args
|
||||
sql_text = str(call_args[0][0])
|
||||
assert 'pg_notify' in sql_text
|
||||
@@ -1,624 +0,0 @@
|
||||
"""Tests for the automation executor.
|
||||
|
||||
Uses real SQLite database operations for event processing, run claiming,
|
||||
and stale run recovery. HTTP calls to the V1 API are mocked.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from services.automation_executor import (
|
||||
_mark_run_failed,
|
||||
claim_and_execute_runs,
|
||||
find_matching_automations,
|
||||
is_terminal,
|
||||
process_new_events,
|
||||
recover_stale_runs,
|
||||
utc_now,
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from storage.automation import Automation, AutomationRun
|
||||
from storage.automation_event import AutomationEvent
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_automation(
|
||||
automation_id: str = 'auto-1',
|
||||
user_id: str = 'user-1',
|
||||
enabled: bool = True,
|
||||
trigger_type: str = 'cron',
|
||||
name: str = 'Test Automation',
|
||||
) -> Automation:
|
||||
return Automation(
|
||||
id=automation_id,
|
||||
user_id=user_id,
|
||||
org_id='org-1',
|
||||
name=name,
|
||||
enabled=enabled,
|
||||
config={'triggers': {'cron': {'schedule': '0 9 * * 5'}}},
|
||||
trigger_type=trigger_type,
|
||||
file_store_key=f'automations/{automation_id}/script.py',
|
||||
)
|
||||
|
||||
|
||||
def make_event(
|
||||
source_type: str = 'cron',
|
||||
payload: dict | None = None,
|
||||
status: str = 'NEW',
|
||||
dedup_key: str | None = None,
|
||||
) -> AutomationEvent:
|
||||
return AutomationEvent(
|
||||
source_type=source_type,
|
||||
payload=payload or {'automation_id': 'auto-1'},
|
||||
dedup_key=dedup_key or f'dedup-{uuid4().hex[:8]}',
|
||||
status=status,
|
||||
created_at=utc_now(),
|
||||
)
|
||||
|
||||
|
||||
def make_run(
|
||||
run_id: str | None = None,
|
||||
automation_id: str = 'auto-1',
|
||||
status: str = 'PENDING',
|
||||
claimed_by: str | None = None,
|
||||
heartbeat_at: datetime | None = None,
|
||||
retry_count: int = 0,
|
||||
max_retries: int = 3,
|
||||
next_retry_at: datetime | None = None,
|
||||
) -> AutomationRun:
|
||||
return AutomationRun(
|
||||
id=run_id or uuid4().hex,
|
||||
automation_id=automation_id,
|
||||
status=status,
|
||||
claimed_by=claimed_by,
|
||||
heartbeat_at=heartbeat_at,
|
||||
retry_count=retry_count,
|
||||
max_retries=max_retries,
|
||||
next_retry_at=next_retry_at,
|
||||
event_payload={'automation_id': automation_id},
|
||||
created_at=utc_now(),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_automations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_matching_automations_cron_event(async_session):
|
||||
"""Cron events match by automation_id in payload."""
|
||||
automation = make_automation()
|
||||
async_session.add(automation)
|
||||
await async_session.commit()
|
||||
|
||||
event = make_event(source_type='cron', payload={'automation_id': 'auto-1'})
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
result = await find_matching_automations(async_session, event)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == 'auto-1'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_matching_automations_manual_event(async_session):
|
||||
"""Manual events also match by automation_id in payload."""
|
||||
automation = make_automation()
|
||||
async_session.add(automation)
|
||||
await async_session.commit()
|
||||
|
||||
event = make_event(source_type='manual', payload={'automation_id': 'auto-1'})
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
result = await find_matching_automations(async_session, event)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == 'auto-1'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_matching_automations_disabled_automation(async_session):
|
||||
"""Disabled automations are not matched."""
|
||||
automation = make_automation(enabled=False)
|
||||
async_session.add(automation)
|
||||
await async_session.commit()
|
||||
|
||||
event = make_event(payload={'automation_id': 'auto-1'})
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
result = await find_matching_automations(async_session, event)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_matching_automations_missing_automation_id(async_session):
|
||||
"""Events without automation_id in payload return empty list."""
|
||||
event = make_event(payload={'something_else': 'value'})
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
result = await find_matching_automations(async_session, event)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_matching_automations_nonexistent_automation(async_session):
|
||||
"""Events referencing a non-existent automation return empty list."""
|
||||
event = make_event(payload={'automation_id': 'nonexistent'})
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
result = await find_matching_automations(async_session, event)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_matching_automations_unknown_source_type(async_session):
|
||||
"""Unknown source types return empty list."""
|
||||
event = make_event(source_type='unknown', payload={'automation_id': 'auto-1'})
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
result = await find_matching_automations(async_session, event)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_new_events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_new_events_creates_runs(async_session):
|
||||
"""Processing NEW events creates PENDING runs and marks events PROCESSED."""
|
||||
automation = make_automation()
|
||||
event = make_event(payload={'automation_id': 'auto-1'})
|
||||
async_session.add_all([automation, event])
|
||||
await async_session.commit()
|
||||
|
||||
count = await process_new_events(async_session)
|
||||
|
||||
assert count == 1
|
||||
|
||||
# Event should be PROCESSED
|
||||
await async_session.refresh(event)
|
||||
assert event.status == 'PROCESSED'
|
||||
assert event.processed_at is not None
|
||||
|
||||
# A run should have been created
|
||||
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
|
||||
assert len(runs) == 1
|
||||
assert runs[0].automation_id == 'auto-1'
|
||||
assert runs[0].status == 'PENDING'
|
||||
assert runs[0].event_payload == {'automation_id': 'auto-1'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_new_events_no_match(async_session):
|
||||
"""Events with no matching automation are marked NO_MATCH."""
|
||||
event = make_event(payload={'automation_id': 'nonexistent'})
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
count = await process_new_events(async_session)
|
||||
|
||||
assert count == 1
|
||||
|
||||
await async_session.refresh(event)
|
||||
assert event.status == 'NO_MATCH'
|
||||
assert event.processed_at is not None
|
||||
|
||||
# No runs created
|
||||
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
|
||||
assert len(runs) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_new_events_skips_processed(async_session):
|
||||
"""Already processed events are not re-processed."""
|
||||
event = make_event(status='PROCESSED')
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
count = await process_new_events(async_session)
|
||||
|
||||
assert count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_new_events_multiple_events(async_session):
|
||||
"""Multiple NEW events are processed in one batch."""
|
||||
auto1 = make_automation(automation_id='auto-1')
|
||||
auto2 = make_automation(automation_id='auto-2', name='Auto 2')
|
||||
event1 = make_event(payload={'automation_id': 'auto-1'}, dedup_key='dedup-1')
|
||||
event2 = make_event(payload={'automation_id': 'auto-2'}, dedup_key='dedup-2')
|
||||
event3 = make_event(payload={'automation_id': 'nonexistent'}, dedup_key='dedup-3')
|
||||
async_session.add_all([auto1, auto2, event1, event2, event3])
|
||||
await async_session.commit()
|
||||
|
||||
count = await process_new_events(async_session)
|
||||
|
||||
assert count == 3
|
||||
|
||||
# Two runs created (for auto-1 and auto-2), none for nonexistent
|
||||
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
|
||||
assert len(runs) == 2
|
||||
|
||||
await async_session.refresh(event1)
|
||||
await async_session.refresh(event2)
|
||||
await async_session.refresh(event3)
|
||||
assert event1.status == 'PROCESSED'
|
||||
assert event2.status == 'PROCESSED'
|
||||
assert event3.status == 'NO_MATCH'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# claim_and_execute_runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_and_execute_runs_claims_pending(
|
||||
async_session, async_session_factory
|
||||
):
|
||||
"""Claims a PENDING run and transitions to RUNNING."""
|
||||
automation = make_automation()
|
||||
run = make_run(run_id='run-1')
|
||||
async_session.add_all([automation, run])
|
||||
await async_session.commit()
|
||||
|
||||
api_client = AsyncMock()
|
||||
|
||||
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
|
||||
claimed = await claim_and_execute_runs(
|
||||
async_session, 'executor-test-1', api_client, async_session_factory
|
||||
)
|
||||
|
||||
assert claimed is True
|
||||
|
||||
await async_session.refresh(run)
|
||||
assert run.status == 'RUNNING'
|
||||
assert run.claimed_by == 'executor-test-1'
|
||||
assert run.claimed_at is not None
|
||||
assert run.heartbeat_at is not None
|
||||
assert run.started_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_and_execute_runs_no_pending(async_session, async_session_factory):
|
||||
"""Returns False when no PENDING runs exist."""
|
||||
api_client = AsyncMock()
|
||||
|
||||
claimed = await claim_and_execute_runs(
|
||||
async_session, 'executor-test-1', api_client, async_session_factory
|
||||
)
|
||||
|
||||
assert claimed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_and_execute_runs_respects_next_retry_at(
|
||||
async_session, async_session_factory
|
||||
):
|
||||
"""Runs with future next_retry_at are not claimed."""
|
||||
automation = make_automation()
|
||||
run = make_run(
|
||||
run_id='run-retry',
|
||||
next_retry_at=utc_now() + timedelta(hours=1),
|
||||
)
|
||||
async_session.add_all([automation, run])
|
||||
await async_session.commit()
|
||||
|
||||
api_client = AsyncMock()
|
||||
|
||||
claimed = await claim_and_execute_runs(
|
||||
async_session, 'executor-test-1', api_client, async_session_factory
|
||||
)
|
||||
|
||||
assert claimed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_and_execute_runs_past_retry_at(
|
||||
async_session, async_session_factory
|
||||
):
|
||||
"""Runs with past next_retry_at are claimable."""
|
||||
automation = make_automation()
|
||||
run = make_run(
|
||||
run_id='run-retry-past',
|
||||
next_retry_at=utc_now() - timedelta(minutes=5),
|
||||
)
|
||||
async_session.add_all([automation, run])
|
||||
await async_session.commit()
|
||||
|
||||
api_client = AsyncMock()
|
||||
|
||||
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
|
||||
claimed = await claim_and_execute_runs(
|
||||
async_session, 'executor-test-1', api_client, async_session_factory
|
||||
)
|
||||
|
||||
assert claimed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_skips_running_runs(async_session, async_session_factory):
|
||||
"""RUNNING runs are not claimed."""
|
||||
automation = make_automation()
|
||||
run = make_run(run_id='run-running', status='RUNNING', claimed_by='other-executor')
|
||||
async_session.add_all([automation, run])
|
||||
await async_session.commit()
|
||||
|
||||
api_client = AsyncMock()
|
||||
|
||||
claimed = await claim_and_execute_runs(
|
||||
async_session, 'executor-test-1', api_client, async_session_factory
|
||||
)
|
||||
|
||||
assert claimed is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# recover_stale_runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recover_stale_runs_recovers_stale(async_session):
|
||||
"""RUNNING runs with expired heartbeats are recovered to PENDING."""
|
||||
automation = make_automation()
|
||||
stale_run = make_run(
|
||||
run_id='stale-1',
|
||||
status='RUNNING',
|
||||
claimed_by='crashed-executor',
|
||||
heartbeat_at=utc_now() - timedelta(minutes=10),
|
||||
retry_count=0,
|
||||
)
|
||||
async_session.add_all([automation, stale_run])
|
||||
await async_session.commit()
|
||||
|
||||
count = await recover_stale_runs(async_session)
|
||||
|
||||
assert count >= 1
|
||||
|
||||
await async_session.refresh(stale_run)
|
||||
assert stale_run.status == 'PENDING'
|
||||
assert stale_run.claimed_by is None
|
||||
assert stale_run.retry_count == 1
|
||||
assert stale_run.next_retry_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recover_stale_runs_ignores_fresh(async_session):
|
||||
"""RUNNING runs with recent heartbeats are not recovered."""
|
||||
automation = make_automation()
|
||||
fresh_run = make_run(
|
||||
run_id='fresh-1',
|
||||
status='RUNNING',
|
||||
claimed_by='active-executor',
|
||||
heartbeat_at=utc_now() - timedelta(seconds=30),
|
||||
)
|
||||
async_session.add_all([automation, fresh_run])
|
||||
await async_session.commit()
|
||||
|
||||
count = await recover_stale_runs(async_session)
|
||||
|
||||
assert count == 0
|
||||
|
||||
await async_session.refresh(fresh_run)
|
||||
assert fresh_run.status == 'RUNNING'
|
||||
assert fresh_run.claimed_by == 'active-executor'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recover_stale_runs_ignores_pending(async_session):
|
||||
"""PENDING runs are not affected by recovery."""
|
||||
automation = make_automation()
|
||||
pending_run = make_run(run_id='pending-1', status='PENDING')
|
||||
async_session.add_all([automation, pending_run])
|
||||
await async_session.commit()
|
||||
|
||||
count = await recover_stale_runs(async_session)
|
||||
|
||||
assert count == 0
|
||||
|
||||
await async_session.refresh(pending_run)
|
||||
assert pending_run.status == 'PENDING'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recover_stale_runs_increments_retry_count(async_session):
|
||||
"""Recovery increments the retry_count."""
|
||||
automation = make_automation()
|
||||
stale_run = make_run(
|
||||
run_id='stale-retry',
|
||||
status='RUNNING',
|
||||
claimed_by='old-executor',
|
||||
heartbeat_at=utc_now() - timedelta(minutes=10),
|
||||
retry_count=2,
|
||||
)
|
||||
async_session.add_all([automation, stale_run])
|
||||
await async_session.commit()
|
||||
|
||||
await recover_stale_runs(async_session)
|
||||
|
||||
await async_session.refresh(stale_run)
|
||||
assert stale_run.retry_count == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _mark_run_failed (error handling)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_run_failed_retries(async_session_factory):
|
||||
"""Failed runs with retries left return to PENDING."""
|
||||
async with async_session_factory() as session:
|
||||
automation = make_automation()
|
||||
run = make_run(run_id='fail-retry', retry_count=0, max_retries=3)
|
||||
session.add_all([automation, run])
|
||||
await session.commit()
|
||||
|
||||
async with async_session_factory() as session:
|
||||
run_obj = await session.get(AutomationRun, 'fail-retry')
|
||||
await _mark_run_failed(run_obj, 'API error', async_session_factory)
|
||||
|
||||
async with async_session_factory() as session:
|
||||
run_obj = await session.get(AutomationRun, 'fail-retry')
|
||||
assert run_obj.status == 'PENDING'
|
||||
assert run_obj.retry_count == 1
|
||||
assert run_obj.error_detail == 'API error'
|
||||
assert run_obj.next_retry_at is not None
|
||||
assert run_obj.claimed_by is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_run_failed_dead_letter(async_session_factory):
|
||||
"""Failed runs that exceed max_retries go to DEAD_LETTER."""
|
||||
async with async_session_factory() as session:
|
||||
automation = make_automation()
|
||||
run = make_run(run_id='fail-dead', retry_count=2, max_retries=3)
|
||||
session.add_all([automation, run])
|
||||
await session.commit()
|
||||
|
||||
async with async_session_factory() as session:
|
||||
run_obj = await session.get(AutomationRun, 'fail-dead')
|
||||
await _mark_run_failed(run_obj, 'Final failure', async_session_factory)
|
||||
|
||||
async with async_session_factory() as session:
|
||||
run_obj = await session.get(AutomationRun, 'fail-dead')
|
||||
assert run_obj.status == 'DEAD_LETTER'
|
||||
assert run_obj.retry_count == 3
|
||||
assert run_obj.error_detail == 'Final failure'
|
||||
assert run_obj.completed_at is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_terminal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_terminal_stopped():
|
||||
assert is_terminal({'status': 'STOPPED'}) is True
|
||||
|
||||
|
||||
def test_is_terminal_error():
|
||||
assert is_terminal({'status': 'ERROR'}) is True
|
||||
|
||||
|
||||
def test_is_terminal_completed():
|
||||
assert is_terminal({'status': 'COMPLETED'}) is True
|
||||
|
||||
|
||||
def test_is_terminal_cancelled():
|
||||
assert is_terminal({'status': 'CANCELLED'}) is True
|
||||
|
||||
|
||||
def test_is_terminal_running():
|
||||
assert is_terminal({'status': 'RUNNING'}) is False
|
||||
|
||||
|
||||
def test_is_terminal_empty():
|
||||
assert is_terminal({}) is False
|
||||
|
||||
|
||||
def test_is_terminal_case_insensitive():
|
||||
assert is_terminal({'status': 'stopped'}) is True
|
||||
assert is_terminal({'status': 'Completed'}) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_automations — None payload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_matching_automations_none_payload(async_session):
|
||||
"""Events with None payload return empty list (data corruption guard)."""
|
||||
event = make_event(source_type='cron')
|
||||
event.payload = None
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
result = await find_matching_automations(async_session, event)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: event → run creation → claim
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_integration_event_to_run_to_claim(
|
||||
async_session_factory,
|
||||
):
|
||||
"""Full flow: create event + automation → process_new_events → claim_and_execute_runs.
|
||||
|
||||
Uses a real SQLite database; only the external API client is mocked.
|
||||
"""
|
||||
# 1. Seed an automation and a NEW event
|
||||
async with async_session_factory() as session:
|
||||
automation = make_automation(automation_id='integ-auto')
|
||||
event = make_event(
|
||||
source_type='cron',
|
||||
payload={'automation_id': 'integ-auto'},
|
||||
dedup_key='integ-dedup',
|
||||
)
|
||||
session.add_all([automation, event])
|
||||
await session.commit()
|
||||
event_id = event.id
|
||||
|
||||
# 2. Process inbox — should match and create a PENDING run
|
||||
async with async_session_factory() as session:
|
||||
processed = await process_new_events(session)
|
||||
|
||||
assert processed == 1
|
||||
|
||||
# Verify event is PROCESSED and run was created
|
||||
async with async_session_factory() as session:
|
||||
evt = await session.get(AutomationEvent, event_id)
|
||||
assert evt.status == 'PROCESSED'
|
||||
|
||||
runs = (await session.execute(select(AutomationRun))).scalars().all()
|
||||
assert len(runs) == 1
|
||||
run = runs[0]
|
||||
assert run.automation_id == 'integ-auto'
|
||||
assert run.status == 'PENDING'
|
||||
assert run.event_payload == {'automation_id': 'integ-auto'}
|
||||
|
||||
# 3. Claim the run — mock execute_run to avoid real API calls
|
||||
api_client = AsyncMock()
|
||||
|
||||
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
|
||||
async with async_session_factory() as session:
|
||||
claimed = await claim_and_execute_runs(
|
||||
session, 'executor-integ', api_client, async_session_factory
|
||||
)
|
||||
|
||||
assert claimed is True
|
||||
|
||||
# 4. Verify the run moved to RUNNING with correct executor
|
||||
async with async_session_factory() as session:
|
||||
runs = (await session.execute(select(AutomationRun))).scalars().all()
|
||||
assert len(runs) == 1
|
||||
run = runs[0]
|
||||
assert run.status == 'RUNNING'
|
||||
assert run.claimed_by == 'executor-integ'
|
||||
assert run.started_at is not None
|
||||
assert run.heartbeat_at is not None
|
||||
104
enterprise/tests/unit/services/test_automation_file_generator.py
Normal file
104
enterprise/tests/unit/services/test_automation_file_generator.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Tests for automation file generator."""
|
||||
|
||||
import ast
|
||||
|
||||
from services.automation_config import extract_config, validate_config
|
||||
from services.automation_file_generator import generate_automation_file
|
||||
|
||||
|
||||
class TestGenerateAutomationFile:
|
||||
def test_generates_valid_python(self):
|
||||
source = generate_automation_file(
|
||||
name='Daily Report',
|
||||
schedule='0 9 * * 1',
|
||||
timezone='UTC',
|
||||
prompt='Generate the daily status report.',
|
||||
)
|
||||
# Must parse without error
|
||||
ast.parse(source)
|
||||
|
||||
def test_contains_config(self):
|
||||
source = generate_automation_file(
|
||||
name='Test Automation',
|
||||
schedule='*/5 * * * *',
|
||||
timezone='America/New_York',
|
||||
prompt='Do something useful.',
|
||||
)
|
||||
cfg = extract_config(source)
|
||||
assert cfg['name'] == 'Test Automation'
|
||||
assert cfg['triggers']['cron']['schedule'] == '*/5 * * * *'
|
||||
assert cfg['triggers']['cron']['timezone'] == 'America/New_York'
|
||||
|
||||
def test_round_trip(self):
|
||||
"""Generate → extract → validate must succeed."""
|
||||
source = generate_automation_file(
|
||||
name='Round Trip',
|
||||
schedule='30 14 * * 0',
|
||||
timezone='Europe/London',
|
||||
prompt='Weekly summary please.',
|
||||
)
|
||||
cfg = extract_config(source)
|
||||
model = validate_config(cfg)
|
||||
assert model.name == 'Round Trip'
|
||||
assert model.triggers.cron is not None
|
||||
assert model.triggers.cron.schedule == '30 14 * * 0'
|
||||
assert model.triggers.cron.timezone == 'Europe/London'
|
||||
|
||||
def test_contains_prompt(self):
|
||||
source = generate_automation_file(
|
||||
name='Prompt Test',
|
||||
schedule='0 0 * * *',
|
||||
timezone='UTC',
|
||||
prompt='Hello world!',
|
||||
)
|
||||
assert 'Hello world!' in source
|
||||
|
||||
def test_contains_docstring(self):
|
||||
source = generate_automation_file(
|
||||
name='Doc Test',
|
||||
schedule='0 0 * * *',
|
||||
timezone='UTC',
|
||||
prompt='test',
|
||||
)
|
||||
assert 'Doc Test' in source
|
||||
assert 'auto-generated automation' in source
|
||||
|
||||
def test_special_characters_in_prompt(self):
|
||||
source = generate_automation_file(
|
||||
name='Special Chars',
|
||||
schedule='0 0 * * *',
|
||||
timezone='UTC',
|
||||
prompt='Check the "status" of \\n stuff',
|
||||
)
|
||||
# Must still be valid Python
|
||||
ast.parse(source)
|
||||
cfg = extract_config(source)
|
||||
assert cfg['name'] == 'Special Chars'
|
||||
|
||||
def test_triple_quotes_in_name(self):
|
||||
"""Names containing triple quotes must not break the generated file."""
|
||||
source = generate_automation_file(
|
||||
name='Test """Demo""" Name',
|
||||
schedule='0 0 * * *',
|
||||
timezone='UTC',
|
||||
prompt='hello',
|
||||
)
|
||||
# Must still be valid Python
|
||||
ast.parse(source)
|
||||
cfg = extract_config(source)
|
||||
assert 'Demo' in cfg['name'] # name preserved in config
|
||||
|
||||
def test_triple_quotes_in_prompt(self):
|
||||
"""Prompts containing triple quotes must not break the generated file."""
|
||||
source = generate_automation_file(
|
||||
name='Triple Quote Test',
|
||||
schedule='0 0 * * *',
|
||||
timezone='UTC',
|
||||
prompt='Use """triple quotes""" and \'\'\'single triples\'\'\' safely',
|
||||
)
|
||||
# Must parse without error
|
||||
ast.parse(source)
|
||||
cfg = extract_config(source)
|
||||
assert cfg['name'] == 'Triple Quote Test'
|
||||
# The prompt must survive round-trip
|
||||
assert '"""triple quotes"""' in source or "triple quotes" in source
|
||||
@@ -1,185 +0,0 @@
|
||||
"""Tests for OpenHandsAPIClient with mocked HTTP responses."""
|
||||
|
||||
import base64
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from services.openhands_api_client import OpenHandsAPIClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_client():
|
||||
client = OpenHandsAPIClient(base_url='http://test-server:3000')
|
||||
yield client
|
||||
# close handled in tests that need it
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# start_conversation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_conversation_sends_correct_request(api_client, respx_mock):
|
||||
"""start_conversation sends properly formatted POST with auth header."""
|
||||
automation_file = b'print("hello")'
|
||||
expected_b64 = base64.b64encode(automation_file).decode()
|
||||
|
||||
route = respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json={
|
||||
'app_conversation_id': 'conv-123',
|
||||
'status': 'RUNNING',
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
result = await api_client.start_conversation(
|
||||
api_key='sk-oh-test123',
|
||||
automation_file=automation_file,
|
||||
title='Test Automation',
|
||||
event_payload={'automation_id': 'auto-1'},
|
||||
)
|
||||
|
||||
assert route.called
|
||||
request = route.calls[0].request
|
||||
|
||||
assert request.headers['Authorization'] == 'Bearer sk-oh-test123'
|
||||
|
||||
import json
|
||||
|
||||
body = json.loads(request.content)
|
||||
assert body['automation_file'] == expected_b64
|
||||
assert body['trigger'] == 'automation'
|
||||
assert body['title'] == 'Test Automation'
|
||||
assert body['event_payload'] == {'automation_id': 'auto-1'}
|
||||
|
||||
assert result == {'app_conversation_id': 'conv-123', 'status': 'RUNNING'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_conversation_without_event_payload(api_client, respx_mock):
|
||||
"""start_conversation works with event_payload=None."""
|
||||
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(200, json={'app_conversation_id': 'conv-456'})
|
||||
)
|
||||
|
||||
result = await api_client.start_conversation(
|
||||
api_key='sk-oh-test',
|
||||
automation_file=b'code',
|
||||
title='Test',
|
||||
event_payload=None,
|
||||
)
|
||||
|
||||
assert result['app_conversation_id'] == 'conv-456'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_conversation_http_error(api_client, respx_mock):
|
||||
"""start_conversation raises on HTTP errors."""
|
||||
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(500, json={'error': 'Internal Server Error'})
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError) as exc_info:
|
||||
await api_client.start_conversation(
|
||||
api_key='sk-oh-test',
|
||||
automation_file=b'code',
|
||||
title='Test',
|
||||
)
|
||||
|
||||
assert exc_info.value.response.status_code == 500
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_conversation_auth_error(api_client, respx_mock):
|
||||
"""start_conversation raises on 401 Unauthorized."""
|
||||
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(401, json={'error': 'Unauthorized'})
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError) as exc_info:
|
||||
await api_client.start_conversation(
|
||||
api_key='bad-key',
|
||||
automation_file=b'code',
|
||||
title='Test',
|
||||
)
|
||||
|
||||
assert exc_info.value.response.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_conversation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_returns_data(api_client, respx_mock):
|
||||
"""get_conversation returns the first conversation from the list."""
|
||||
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json=[
|
||||
{
|
||||
'conversation_id': 'conv-123',
|
||||
'status': 'RUNNING',
|
||||
'title': 'My Automation',
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
result = await api_client.get_conversation('sk-oh-test', 'conv-123')
|
||||
|
||||
assert result is not None
|
||||
assert result['conversation_id'] == 'conv-123'
|
||||
assert result['status'] == 'RUNNING'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_returns_none_when_empty(api_client, respx_mock):
|
||||
"""get_conversation returns None when API returns empty list."""
|
||||
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(200, json=[])
|
||||
)
|
||||
|
||||
result = await api_client.get_conversation('sk-oh-test', 'nonexistent')
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_sends_auth_header(api_client, respx_mock):
|
||||
"""get_conversation sends the correct authorization header."""
|
||||
route = respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(200, json=[])
|
||||
)
|
||||
|
||||
await api_client.get_conversation('sk-oh-mykey', 'conv-1')
|
||||
|
||||
assert route.called
|
||||
request = route.calls[0].request
|
||||
assert request.headers['Authorization'] == 'Bearer sk-oh-mykey'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_http_error(api_client, respx_mock):
|
||||
"""get_conversation raises on HTTP errors."""
|
||||
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(503, text='Service Unavailable')
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await api_client.get_conversation('sk-oh-test', 'conv-1')
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# close
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close(api_client):
|
||||
"""close() shuts down the HTTP client without errors."""
|
||||
await api_client.close()
|
||||
184
enterprise/tests/unit/storage/test_automation_models.py
Normal file
184
enterprise/tests/unit/storage/test_automation_models.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Tests for Automation and AutomationRun SQLAlchemy models."""
|
||||
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.automation import Automation, AutomationRun
|
||||
from storage.automation_event import AutomationEvent
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
def _make_engine():
|
||||
engine = create_engine('sqlite://', connect_args={'check_same_thread': False})
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
class TestAutomationModel:
|
||||
def test_create_automation_with_defaults(self):
|
||||
engine = _make_engine()
|
||||
Session = sessionmaker(bind=engine)
|
||||
with Session() as session:
|
||||
auto = Automation(
|
||||
id=uuid.uuid4().hex,
|
||||
user_id='user-1',
|
||||
name='My Automation',
|
||||
config={'name': 'My Automation', 'triggers': {'cron': {'schedule': '0 9 * * 1'}}},
|
||||
trigger_type='cron',
|
||||
file_store_key='automations/user-1/abc.py',
|
||||
)
|
||||
session.add(auto)
|
||||
session.commit()
|
||||
|
||||
fetched = session.get(Automation, auto.id)
|
||||
assert fetched is not None
|
||||
assert fetched.name == 'My Automation'
|
||||
assert fetched.user_id == 'user-1'
|
||||
assert fetched.trigger_type == 'cron'
|
||||
assert fetched.org_id is None
|
||||
assert fetched.last_triggered_at is None
|
||||
|
||||
def test_automation_with_org_id(self):
|
||||
engine = _make_engine()
|
||||
Session = sessionmaker(bind=engine)
|
||||
with Session() as session:
|
||||
auto = Automation(
|
||||
id=uuid.uuid4().hex,
|
||||
user_id='user-1',
|
||||
org_id='org-123',
|
||||
name='Org Automation',
|
||||
config={'name': 'Org Automation'},
|
||||
trigger_type='cron',
|
||||
file_store_key='automations/user-1/xyz.py',
|
||||
)
|
||||
session.add(auto)
|
||||
session.commit()
|
||||
|
||||
fetched = session.get(Automation, auto.id)
|
||||
assert fetched is not None
|
||||
assert fetched.org_id == 'org-123'
|
||||
|
||||
|
||||
class TestAutomationRunModel:
|
||||
def test_create_run_with_defaults(self):
|
||||
engine = _make_engine()
|
||||
Session = sessionmaker(bind=engine)
|
||||
with Session() as session:
|
||||
auto = Automation(
|
||||
id=uuid.uuid4().hex,
|
||||
user_id='user-1',
|
||||
name='Test',
|
||||
config={},
|
||||
trigger_type='cron',
|
||||
file_store_key='test.py',
|
||||
)
|
||||
session.add(auto)
|
||||
session.flush()
|
||||
|
||||
run = AutomationRun(
|
||||
id=uuid.uuid4().hex,
|
||||
automation_id=auto.id,
|
||||
)
|
||||
session.add(run)
|
||||
session.commit()
|
||||
|
||||
fetched = session.get(AutomationRun, run.id)
|
||||
assert fetched is not None
|
||||
assert fetched.automation_id == auto.id
|
||||
assert fetched.conversation_id is None
|
||||
assert fetched.event_id is None
|
||||
assert fetched.claimed_by is None
|
||||
assert fetched.error_detail is None
|
||||
|
||||
def test_run_relationship(self):
|
||||
engine = _make_engine()
|
||||
Session = sessionmaker(bind=engine)
|
||||
with Session() as session:
|
||||
auto = Automation(
|
||||
id=uuid.uuid4().hex,
|
||||
user_id='user-1',
|
||||
name='Test',
|
||||
config={},
|
||||
trigger_type='cron',
|
||||
file_store_key='test.py',
|
||||
)
|
||||
session.add(auto)
|
||||
session.flush()
|
||||
|
||||
run = AutomationRun(
|
||||
id=uuid.uuid4().hex,
|
||||
automation_id=auto.id,
|
||||
)
|
||||
session.add(run)
|
||||
session.commit()
|
||||
|
||||
session.refresh(auto)
|
||||
assert len(auto.runs) == 1
|
||||
assert auto.runs[0].id == run.id
|
||||
assert run.automation.id == auto.id
|
||||
|
||||
|
||||
class TestAutomationEventModel:
|
||||
def test_create_event(self):
|
||||
engine = _make_engine()
|
||||
Session = sessionmaker(bind=engine)
|
||||
with Session() as session:
|
||||
event = AutomationEvent(
|
||||
source_type='cron',
|
||||
payload={'tick': True},
|
||||
dedup_key='cron-2025-01-01T00:00:00',
|
||||
)
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
fetched = session.get(AutomationEvent, event.id)
|
||||
assert fetched is not None
|
||||
assert fetched.source_type == 'cron'
|
||||
assert fetched.payload == {'tick': True}
|
||||
assert fetched.dedup_key == 'cron-2025-01-01T00:00:00'
|
||||
assert fetched.error_detail is None
|
||||
|
||||
def test_event_with_metadata(self):
|
||||
engine = _make_engine()
|
||||
Session = sessionmaker(bind=engine)
|
||||
with Session() as session:
|
||||
event = AutomationEvent(
|
||||
source_type='github',
|
||||
payload={'action': 'opened'},
|
||||
dedup_key='gh-12345',
|
||||
metadata_={'installation_id': 99},
|
||||
)
|
||||
session.add(event)
|
||||
session.commit()
|
||||
|
||||
fetched = session.get(AutomationEvent, event.id)
|
||||
assert fetched is not None
|
||||
assert fetched.metadata_ == {'installation_id': 99}
|
||||
|
||||
def test_dedup_key_unique(self):
|
||||
engine = _make_engine()
|
||||
Session = sessionmaker(bind=engine)
|
||||
import sqlalchemy
|
||||
|
||||
with Session() as session:
|
||||
e1 = AutomationEvent(
|
||||
source_type='cron',
|
||||
payload={},
|
||||
dedup_key='dup-key',
|
||||
)
|
||||
session.add(e1)
|
||||
session.commit()
|
||||
|
||||
with Session() as session:
|
||||
e2 = AutomationEvent(
|
||||
source_type='cron',
|
||||
payload={},
|
||||
dedup_key='dup-key',
|
||||
)
|
||||
session.add(e2)
|
||||
try:
|
||||
session.commit()
|
||||
assert False, 'Expected IntegrityError'
|
||||
except sqlalchemy.exc.IntegrityError:
|
||||
session.rollback()
|
||||
@@ -16,6 +16,7 @@ class ConversationTrigger(Enum):
|
||||
JIRA_DC = 'jira_dc'
|
||||
LINEAR = 'linear'
|
||||
BITBUCKET = 'bitbucket'
|
||||
AUTOMATION = 'automation'
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user