mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
3 Commits
1.5.0
...
auto/crud-api
| Author | SHA1 | Date | |
|---|---|---|---|
| 6ec03098ad | |||
| 20e0ebacf0 | |||
| 3230813a95 |
@@ -27,6 +27,7 @@ from server.middleware import SetAuthCookieMiddleware # noqa: E402
|
||||
from server.rate_limit import setup_rate_limit_handler # noqa: E402
|
||||
from server.routes.api_keys import api_router as api_keys_router # noqa: E402
|
||||
from server.routes.auth import api_router, oauth_router # noqa: E402
|
||||
from server.routes.automations import automation_router # noqa: E402
|
||||
from server.routes.billing import billing_router # noqa: E402
|
||||
from server.routes.email import api_router as email_router # noqa: E402
|
||||
from server.routes.event_webhook import event_webhook_router # noqa: E402
|
||||
@@ -139,6 +140,8 @@ if BITBUCKET_DATA_CENTER_HOST:
|
||||
base_app.include_router(bitbucket_dc_proxy_router)
|
||||
base_app.include_router(email_router) # Add routes for email management
|
||||
base_app.include_router(feedback_router) # Add routes for conversation feedback
|
||||
|
||||
base_app.include_router(automation_router) # Add routes for automation CRUD
|
||||
base_app.include_router(
|
||||
event_webhook_router
|
||||
) # Add routes for Events in nested runtimes
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
"""Pydantic request/response models for automation CRUD API."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CreateAutomationRequest(BaseModel):
|
||||
"""Simple mode (Phase 1): form input → generated file."""
|
||||
|
||||
name: str = Field(min_length=1, max_length=200)
|
||||
schedule: str # 5-field cron expression
|
||||
timezone: str = 'UTC'
|
||||
prompt: str = Field(min_length=1)
|
||||
repository: str | None = None # e.g., "owner/repo"
|
||||
branch: str | None = None
|
||||
|
||||
|
||||
class UpdateAutomationRequest(BaseModel):
|
||||
name: str | None = None
|
||||
schedule: str | None = None
|
||||
timezone: str | None = None
|
||||
prompt: str | None = None
|
||||
repository: str | None = None
|
||||
branch: str | None = None
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class AutomationResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
enabled: bool
|
||||
trigger_type: str
|
||||
config: dict
|
||||
file_url: str | None = None
|
||||
last_triggered_at: str | None = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class AutomationRunResponse(BaseModel):
|
||||
id: str
|
||||
automation_id: str
|
||||
conversation_id: str | None = None
|
||||
status: str
|
||||
error_detail: str | None = None
|
||||
started_at: str | None = None
|
||||
completed_at: str | None = None
|
||||
created_at: str
|
||||
|
||||
|
||||
class PaginatedAutomationsResponse(BaseModel):
|
||||
items: list[AutomationResponse]
|
||||
total: int
|
||||
next_page_id: str | None = None
|
||||
|
||||
|
||||
class PaginatedRunsResponse(BaseModel):
|
||||
items: list[AutomationRunResponse]
|
||||
total: int
|
||||
next_page_id: str | None = None
|
||||
@@ -0,0 +1,465 @@
|
||||
"""FastAPI router for automation CRUD API (Phase 1: simple mode only)."""
|
||||
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from services.automation_config import extract_config, validate_config
|
||||
from services.automation_event_publisher import publish_automation_event
|
||||
from services.automation_file_generator import generate_automation_file
|
||||
from sqlalchemy import delete, func, select
|
||||
from storage.automation import Automation, AutomationRun
|
||||
from storage.database import a_session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.shared import file_store
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
from .automation_models import (
|
||||
AutomationResponse,
|
||||
AutomationRunResponse,
|
||||
CreateAutomationRequest,
|
||||
PaginatedAutomationsResponse,
|
||||
PaginatedRunsResponse,
|
||||
UpdateAutomationRequest,
|
||||
)
|
||||
|
||||
automation_router = APIRouter(
|
||||
prefix='/api/v1/automations',
|
||||
tags=['automations'],
|
||||
)
|
||||
|
||||
FILE_STORE_PREFIX = 'automations'
|
||||
|
||||
|
||||
def _file_store_key(automation_id: str) -> str:
|
||||
return f'{FILE_STORE_PREFIX}/{automation_id}/automation.py'
|
||||
|
||||
|
||||
def _paginate(rows: list, limit: int, id_attr: str = 'id') -> tuple[list, str | None]:
|
||||
"""Return (items, next_page_id) from an overfetched result set."""
|
||||
if len(rows) > limit:
|
||||
return rows[:limit], getattr(rows[limit], id_attr)
|
||||
return rows, None
|
||||
|
||||
|
||||
def _automation_to_response(automation: Automation) -> AutomationResponse:
|
||||
return AutomationResponse(
|
||||
id=automation.id,
|
||||
name=automation.name,
|
||||
enabled=automation.enabled,
|
||||
trigger_type=automation.trigger_type,
|
||||
config=automation.config or {},
|
||||
file_url=None,
|
||||
last_triggered_at=(
|
||||
automation.last_triggered_at.isoformat()
|
||||
if automation.last_triggered_at
|
||||
else None
|
||||
),
|
||||
created_at=automation.created_at.isoformat() if automation.created_at else '',
|
||||
updated_at=automation.updated_at.isoformat() if automation.updated_at else '',
|
||||
)
|
||||
|
||||
|
||||
def _run_to_response(run: AutomationRun) -> AutomationRunResponse:
|
||||
return AutomationRunResponse(
|
||||
id=run.id,
|
||||
automation_id=run.automation_id,
|
||||
conversation_id=run.conversation_id,
|
||||
status=run.status,
|
||||
error_detail=run.error_detail,
|
||||
started_at=run.started_at.isoformat() if run.started_at else None,
|
||||
completed_at=run.completed_at.isoformat() if run.completed_at else None,
|
||||
created_at=run.created_at.isoformat() if run.created_at else '',
|
||||
)
|
||||
|
||||
|
||||
def _generate_and_validate_file(
|
||||
name: str,
|
||||
schedule: str,
|
||||
timezone: str,
|
||||
prompt: str,
|
||||
repository: str | None = None,
|
||||
branch: str | None = None,
|
||||
) -> tuple[str, dict]:
|
||||
"""Generate automation file, extract config, validate, and store prompt in config.
|
||||
|
||||
Returns (file_content, config_dict).
|
||||
Raises HTTPException on validation failure.
|
||||
"""
|
||||
file_content = generate_automation_file(
|
||||
name=name,
|
||||
schedule=schedule,
|
||||
timezone=timezone,
|
||||
prompt=prompt,
|
||||
repository=repository,
|
||||
branch=branch,
|
||||
)
|
||||
config = extract_config(file_content)
|
||||
try:
|
||||
validate_config(config)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
detail=f'Invalid automation config: {e}',
|
||||
)
|
||||
# Store prompt in config so DB is the source of truth (not the file)
|
||||
config['prompt'] = prompt
|
||||
return file_content, config
|
||||
|
||||
|
||||
@automation_router.post('', status_code=status.HTTP_201_CREATED)
|
||||
async def create_automation(
|
||||
request: CreateAutomationRequest,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> AutomationResponse:
|
||||
"""Create an automation from simple mode input (Phase 1).
|
||||
|
||||
Generates a .py file, uploads to object store, stores metadata in DB.
|
||||
"""
|
||||
file_content, config = _generate_and_validate_file(
|
||||
name=request.name,
|
||||
schedule=request.schedule,
|
||||
timezone=request.timezone,
|
||||
prompt=request.prompt,
|
||||
repository=request.repository,
|
||||
branch=request.branch,
|
||||
)
|
||||
|
||||
automation_id = uuid.uuid4().hex
|
||||
key = _file_store_key(automation_id)
|
||||
|
||||
try:
|
||||
file_store.write(key, file_content)
|
||||
except Exception:
|
||||
logger.exception('Failed to upload automation file to object store')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to store automation file',
|
||||
)
|
||||
|
||||
automation = Automation(
|
||||
id=automation_id,
|
||||
user_id=user_id,
|
||||
name=request.name,
|
||||
enabled=True,
|
||||
config=config,
|
||||
trigger_type='cron',
|
||||
file_store_key=key,
|
||||
)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
session.add(automation)
|
||||
await session.commit()
|
||||
await session.refresh(automation)
|
||||
|
||||
logger.info(
|
||||
'Created automation',
|
||||
extra={'automation_id': automation_id, 'user_id': user_id},
|
||||
)
|
||||
return _automation_to_response(automation)
|
||||
|
||||
|
||||
@automation_router.get('/search')
|
||||
async def search_automations(
|
||||
user_id: str = Depends(get_user_id),
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Cursor for pagination (automation ID)'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='Max results per page', gt=0, le=100),
|
||||
] = 20,
|
||||
) -> PaginatedAutomationsResponse:
|
||||
"""List automations for the current user, paginated."""
|
||||
async with a_session_maker() as session:
|
||||
base_filter = select(Automation).where(Automation.user_id == user_id)
|
||||
|
||||
# Total count
|
||||
count_q = select(func.count()).select_from(base_filter.subquery())
|
||||
total = (await session.execute(count_q)).scalar() or 0
|
||||
|
||||
# Paginated query ordered by created_at desc
|
||||
query = base_filter.order_by(Automation.created_at.desc())
|
||||
if page_id:
|
||||
cursor_row = (
|
||||
await session.execute(
|
||||
select(Automation.created_at).where(Automation.id == page_id)
|
||||
)
|
||||
).scalar()
|
||||
if cursor_row is not None:
|
||||
query = query.where(Automation.created_at < cursor_row)
|
||||
|
||||
query = query.limit(limit + 1)
|
||||
result = await session.execute(query)
|
||||
rows = list(result.scalars().all())
|
||||
|
||||
items, next_page_id = _paginate(rows, limit)
|
||||
return PaginatedAutomationsResponse(
|
||||
items=[_automation_to_response(a) for a in items],
|
||||
total=total,
|
||||
next_page_id=next_page_id,
|
||||
)
|
||||
|
||||
|
||||
@automation_router.get('/{automation_id}')
|
||||
async def get_automation(
|
||||
automation_id: str,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> AutomationResponse:
|
||||
"""Get a single automation by ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(Automation).where(
|
||||
Automation.id == automation_id,
|
||||
Automation.user_id == user_id,
|
||||
)
|
||||
)
|
||||
automation = result.scalars().first()
|
||||
|
||||
if not automation:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Automation not found',
|
||||
)
|
||||
return _automation_to_response(automation)
|
||||
|
||||
|
||||
@automation_router.patch('/{automation_id}')
|
||||
async def update_automation(
|
||||
automation_id: str,
|
||||
request: UpdateAutomationRequest,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> AutomationResponse:
|
||||
"""Update an automation. Re-generates file if prompt/schedule/timezone/name changed."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(Automation).where(
|
||||
Automation.id == automation_id,
|
||||
Automation.user_id == user_id,
|
||||
)
|
||||
)
|
||||
automation = result.scalars().first()
|
||||
if not automation:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Automation not found',
|
||||
)
|
||||
|
||||
updates = {
|
||||
k: v
|
||||
for k, v in request.model_dump(exclude_unset=True).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
file_regen_fields = {'schedule', 'timezone', 'prompt', 'name'}
|
||||
needs_regen = bool(updates.keys() & file_regen_fields)
|
||||
|
||||
if needs_regen:
|
||||
current_config = automation.config or {}
|
||||
current_triggers = current_config.get('triggers', {}).get('cron', {})
|
||||
|
||||
# Merge: use request values if provided, else fall back to current config
|
||||
new_name = updates.get('name', automation.name)
|
||||
new_schedule = updates.get(
|
||||
'schedule', current_triggers.get('schedule', '')
|
||||
)
|
||||
new_timezone = updates.get(
|
||||
'timezone', current_triggers.get('timezone', 'UTC')
|
||||
)
|
||||
prompt = updates.get('prompt', current_config.get('prompt', ''))
|
||||
|
||||
file_content, config = _generate_and_validate_file(
|
||||
name=new_name,
|
||||
schedule=new_schedule,
|
||||
timezone=new_timezone,
|
||||
prompt=prompt,
|
||||
repository=updates.get('repository'),
|
||||
branch=updates.get('branch'),
|
||||
)
|
||||
try:
|
||||
file_store.write(automation.file_store_key, file_content)
|
||||
except Exception:
|
||||
logger.exception('Failed to upload updated automation file')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to store updated automation file',
|
||||
)
|
||||
automation.config = config
|
||||
automation.name = new_name
|
||||
|
||||
if 'name' in updates and not needs_regen:
|
||||
automation.name = updates['name']
|
||||
if 'enabled' in updates:
|
||||
automation.enabled = updates['enabled']
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(automation)
|
||||
|
||||
return _automation_to_response(automation)
|
||||
|
||||
|
||||
@automation_router.delete('/{automation_id}', status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_automation(
|
||||
automation_id: str,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> None:
|
||||
"""Delete an automation and all its runs."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(Automation).where(
|
||||
Automation.id == automation_id,
|
||||
Automation.user_id == user_id,
|
||||
)
|
||||
)
|
||||
automation = result.scalars().first()
|
||||
if not automation:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Automation not found',
|
||||
)
|
||||
|
||||
file_key = automation.file_store_key
|
||||
|
||||
# Delete runs first
|
||||
await session.execute(
|
||||
delete(AutomationRun).where(AutomationRun.automation_id == automation_id)
|
||||
)
|
||||
await session.delete(automation)
|
||||
await session.commit()
|
||||
|
||||
# Best-effort cleanup of file store (DB is source of truth)
|
||||
try:
|
||||
file_store.delete(file_key)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
'Failed to delete automation file from object store',
|
||||
extra={'automation_id': automation_id},
|
||||
)
|
||||
|
||||
|
||||
@automation_router.post('/{automation_id}/run', status_code=status.HTTP_202_ACCEPTED)
|
||||
async def trigger_manual_run(
|
||||
automation_id: str,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> dict:
|
||||
"""Manually trigger an automation run."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(Automation).where(
|
||||
Automation.id == automation_id,
|
||||
Automation.user_id == user_id,
|
||||
)
|
||||
)
|
||||
automation = result.scalars().first()
|
||||
if not automation:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Automation not found',
|
||||
)
|
||||
|
||||
dedup_key = f'manual-{automation_id}-{uuid.uuid4().hex}'
|
||||
await publish_automation_event(
|
||||
session=session,
|
||||
source_type='manual',
|
||||
payload={'automation_id': automation_id},
|
||||
dedup_key=dedup_key,
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
return {'status': 'accepted', 'dedup_key': dedup_key}
|
||||
|
||||
|
||||
@automation_router.get('/{automation_id}/runs')
|
||||
async def list_automation_runs(
|
||||
automation_id: str,
|
||||
user_id: str = Depends(get_user_id),
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Cursor for pagination (run ID)'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='Max results per page', gt=0, le=100),
|
||||
] = 20,
|
||||
) -> PaginatedRunsResponse:
|
||||
"""List runs for an automation, paginated."""
|
||||
# Verify ownership
|
||||
async with a_session_maker() as session:
|
||||
ownership = await session.execute(
|
||||
select(Automation.id).where(
|
||||
Automation.id == automation_id,
|
||||
Automation.user_id == user_id,
|
||||
)
|
||||
)
|
||||
if not ownership.scalar():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Automation not found',
|
||||
)
|
||||
|
||||
base_filter = select(AutomationRun).where(
|
||||
AutomationRun.automation_id == automation_id
|
||||
)
|
||||
|
||||
count_q = select(func.count()).select_from(base_filter.subquery())
|
||||
total = (await session.execute(count_q)).scalar() or 0
|
||||
|
||||
query = base_filter.order_by(AutomationRun.created_at.desc())
|
||||
if page_id:
|
||||
cursor_row = (
|
||||
await session.execute(
|
||||
select(AutomationRun.created_at).where(AutomationRun.id == page_id)
|
||||
)
|
||||
).scalar()
|
||||
if cursor_row is not None:
|
||||
query = query.where(AutomationRun.created_at < cursor_row)
|
||||
|
||||
query = query.limit(limit + 1)
|
||||
result = await session.execute(query)
|
||||
rows = list(result.scalars().all())
|
||||
|
||||
items, next_page_id = _paginate(rows, limit)
|
||||
return PaginatedRunsResponse(
|
||||
items=[_run_to_response(r) for r in items],
|
||||
total=total,
|
||||
next_page_id=next_page_id,
|
||||
)
|
||||
|
||||
|
||||
@automation_router.get('/{automation_id}/runs/{run_id}')
|
||||
async def get_automation_run(
|
||||
automation_id: str,
|
||||
run_id: str,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> AutomationRunResponse:
|
||||
"""Get a single run detail."""
|
||||
async with a_session_maker() as session:
|
||||
# Verify ownership of the automation
|
||||
ownership = await session.execute(
|
||||
select(Automation.id).where(
|
||||
Automation.id == automation_id,
|
||||
Automation.user_id == user_id,
|
||||
)
|
||||
)
|
||||
if not ownership.scalar():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Automation not found',
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(AutomationRun).where(
|
||||
AutomationRun.id == run_id,
|
||||
AutomationRun.automation_id == automation_id,
|
||||
)
|
||||
)
|
||||
run = result.scalars().first()
|
||||
|
||||
if not run:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Run not found',
|
||||
)
|
||||
return _run_to_response(run)
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Automation config extraction and validation.
|
||||
|
||||
NOTE: This is a stub for Task 2 (CRUD API) development.
|
||||
Task 1 (Data Foundation) will provide the full implementation.
|
||||
"""
|
||||
|
||||
import ast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CronTriggerModel(BaseModel):
|
||||
schedule: str = Field(pattern=r'^(\S+\s+){4}\S+$')
|
||||
timezone: str = 'UTC'
|
||||
|
||||
|
||||
class TriggersModel(BaseModel):
|
||||
cron: CronTriggerModel | None = None
|
||||
|
||||
def model_post_init(self, __context: object) -> None:
|
||||
defined = [k for k in ('cron',) if getattr(self, k) is not None]
|
||||
if len(defined) != 1:
|
||||
raise ValueError(f'Exactly one trigger required, got: {defined or "none"}')
|
||||
|
||||
|
||||
class AutomationConfigModel(BaseModel):
|
||||
name: str = Field(min_length=1, max_length=200)
|
||||
triggers: TriggersModel
|
||||
description: str = ''
|
||||
|
||||
|
||||
def extract_config(source: str) -> dict:
|
||||
"""Extract __config__ dict from a Python automation file using AST."""
|
||||
tree = ast.parse(source)
|
||||
for node in ast.iter_child_nodes(tree):
|
||||
if isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and target.id == '__config__':
|
||||
return ast.literal_eval(node.value)
|
||||
if isinstance(node, ast.AnnAssign):
|
||||
if (
|
||||
isinstance(node.target, ast.Name)
|
||||
and node.target.id == '__config__'
|
||||
and node.value is not None
|
||||
):
|
||||
return ast.literal_eval(node.value)
|
||||
raise ValueError('No __config__ dict found in automation file')
|
||||
|
||||
|
||||
def validate_config(config: dict) -> AutomationConfigModel:
|
||||
"""Validate a __config__ dict. Returns parsed model or raises ValidationError."""
|
||||
return AutomationConfigModel.model_validate(config)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Automation event publisher.
|
||||
|
||||
NOTE: This is a stub for Task 2 (CRUD API) development.
|
||||
Task 1 (Data Foundation) will provide the full implementation.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from storage.automation_event import AutomationEvent
|
||||
|
||||
|
||||
async def publish_automation_event(
|
||||
session: Any,
|
||||
source_type: str,
|
||||
payload: dict,
|
||||
dedup_key: str,
|
||||
) -> AutomationEvent:
|
||||
"""Insert a new automation event into the automation_events table."""
|
||||
event = AutomationEvent(
|
||||
source_type=source_type,
|
||||
payload=payload,
|
||||
dedup_key=dedup_key,
|
||||
status='NEW',
|
||||
)
|
||||
session.add(event)
|
||||
return event
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Automation file generator for simple mode (Phase 1).
|
||||
|
||||
NOTE: This is a stub for Task 2 (CRUD API) development.
|
||||
Task 1 (Data Foundation) will provide the full implementation.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
PROMPT_TEMPLATE = '''\
|
||||
"""{name} — auto-generated from form input."""
|
||||
|
||||
__config__ = {config_json}
|
||||
|
||||
import os
|
||||
|
||||
from openhands.sdk import LLM, Conversation
|
||||
from openhands.tools.preset.default import get_default_agent
|
||||
|
||||
llm = LLM(
|
||||
model=os.getenv("LLM_MODEL", "anthropic/claude-sonnet-4-5-20250929"),
|
||||
api_key=os.getenv("LLM_API_KEY"),
|
||||
base_url=os.getenv("LLM_BASE_URL"),
|
||||
)
|
||||
agent = get_default_agent(llm=llm, cli_mode=True)
|
||||
conversation = Conversation(agent=agent, workspace=os.getcwd())
|
||||
|
||||
conversation.send_message({prompt!r})
|
||||
conversation.run()
|
||||
'''
|
||||
|
||||
|
||||
def generate_automation_file(
|
||||
name: str,
|
||||
schedule: str,
|
||||
timezone: str,
|
||||
prompt: str,
|
||||
repository: str | None = None,
|
||||
branch: str | None = None,
|
||||
) -> str:
|
||||
"""Generate a Python automation file from form input."""
|
||||
config: dict = {
|
||||
'name': name,
|
||||
'triggers': {'cron': {'schedule': schedule, 'timezone': timezone}},
|
||||
}
|
||||
return PROMPT_TEMPLATE.format(
|
||||
name=name,
|
||||
config_json=json.dumps(config, indent=4),
|
||||
prompt=prompt,
|
||||
)
|
||||
@@ -0,0 +1,47 @@
|
||||
"""SQLAlchemy models for automations.
|
||||
|
||||
NOTE: This is a stub for Task 2 (CRUD API) development.
|
||||
Task 1 (Data Foundation) will provide the full implementation.
|
||||
"""
|
||||
|
||||
from sqlalchemy import JSON, Boolean, Column, DateTime, String
|
||||
from sqlalchemy.sql import func
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class Automation(Base): # type: ignore
|
||||
__tablename__ = 'automations'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False, index=True)
|
||||
org_id = Column(String, nullable=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
enabled = Column(Boolean, nullable=False, default=True)
|
||||
config = Column(JSON, nullable=False)
|
||||
trigger_type = Column(String, nullable=False)
|
||||
file_store_key = Column(String, nullable=False)
|
||||
last_triggered_at = Column(DateTime(timezone=True), nullable=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
|
||||
|
||||
class AutomationRun(Base): # type: ignore
|
||||
__tablename__ = 'automation_runs'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
automation_id = Column(String, nullable=False, index=True)
|
||||
conversation_id = Column(String, nullable=True)
|
||||
status = Column(String, nullable=False, default='PENDING')
|
||||
error_detail = Column(String, nullable=True)
|
||||
started_at = Column(DateTime(timezone=True), nullable=True)
|
||||
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
@@ -0,0 +1,25 @@
|
||||
"""SQLAlchemy model for automation events.
|
||||
|
||||
NOTE: This is a stub for Task 2 (CRUD API) development.
|
||||
Task 1 (Data Foundation) will provide the full implementation.
|
||||
"""
|
||||
|
||||
from sqlalchemy import JSON, BigInteger, Column, DateTime, String
|
||||
from sqlalchemy.sql import func
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class AutomationEvent(Base): # type: ignore
|
||||
__tablename__ = 'automation_events'
|
||||
|
||||
id = Column(BigInteger, primary_key=True, autoincrement=True)
|
||||
source_type = Column(String, nullable=False)
|
||||
payload = Column(JSON, nullable=False)
|
||||
metadata_ = Column('metadata', JSON, nullable=True)
|
||||
dedup_key = Column(String, nullable=False, unique=True)
|
||||
status = Column(String, nullable=False, default='NEW')
|
||||
error_detail = Column(String, nullable=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
processed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
@@ -0,0 +1,653 @@
|
||||
"""Unit tests for automation CRUD API routes."""
|
||||
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, status
|
||||
from fastapi.testclient import TestClient
|
||||
from server.routes.automations import automation_router
|
||||
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
TEST_USER_ID = str(uuid.uuid4())
|
||||
OTHER_USER_ID = str(uuid.uuid4())
|
||||
|
||||
|
||||
def _make_automation(
|
||||
automation_id: str | None = None,
|
||||
user_id: str = TEST_USER_ID,
|
||||
name: str = 'Test Automation',
|
||||
enabled: bool = True,
|
||||
trigger_type: str = 'cron',
|
||||
schedule: str = '0 9 * * 5',
|
||||
timezone: str = 'UTC',
|
||||
file_store_key: str | None = None,
|
||||
):
|
||||
auto_id = automation_id or uuid.uuid4().hex
|
||||
mock = MagicMock()
|
||||
mock.id = auto_id
|
||||
mock.user_id = user_id
|
||||
mock.name = name
|
||||
mock.enabled = enabled
|
||||
mock.trigger_type = trigger_type
|
||||
mock.config = {
|
||||
'name': name,
|
||||
'triggers': {'cron': {'schedule': schedule, 'timezone': timezone}},
|
||||
}
|
||||
mock.file_store_key = file_store_key or f'automations/{auto_id}/automation.py'
|
||||
mock.last_triggered_at = None
|
||||
mock.created_at = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
mock.updated_at = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
return mock
|
||||
|
||||
|
||||
def _make_run(
|
||||
run_id: str | None = None,
|
||||
automation_id: str = 'auto-1',
|
||||
conversation_id: str | None = None,
|
||||
run_status: str = 'PENDING',
|
||||
):
|
||||
rid = run_id or uuid.uuid4().hex
|
||||
mock = MagicMock()
|
||||
mock.id = rid
|
||||
mock.automation_id = automation_id
|
||||
mock.conversation_id = conversation_id
|
||||
mock.status = run_status
|
||||
mock.error_detail = None
|
||||
mock.started_at = None
|
||||
mock.completed_at = None
|
||||
mock.created_at = datetime(2026, 1, 2, tzinfo=UTC)
|
||||
return mock
|
||||
|
||||
|
||||
# --- Helpers to mock async DB sessions ---
|
||||
|
||||
|
||||
def _mock_session_with_results(results_by_call):
|
||||
"""Create a mock async session that returns preconfigured results.
|
||||
|
||||
results_by_call: list of values; each session.execute() returns
|
||||
the next value wrapped in a mock result.
|
||||
"""
|
||||
call_index = [0]
|
||||
|
||||
session = AsyncMock()
|
||||
|
||||
async def _execute(stmt):
|
||||
idx = call_index[0]
|
||||
call_index[0] += 1
|
||||
val = results_by_call[idx] if idx < len(results_by_call) else None
|
||||
result_mock = MagicMock()
|
||||
if isinstance(val, list):
|
||||
result_mock.scalars.return_value.all.return_value = val
|
||||
result_mock.scalars.return_value.first.return_value = (
|
||||
val[0] if val else None
|
||||
)
|
||||
result_mock.scalar.return_value = len(val)
|
||||
elif val is None:
|
||||
result_mock.scalars.return_value.first.return_value = None
|
||||
result_mock.scalars.return_value.all.return_value = []
|
||||
result_mock.scalar.return_value = None
|
||||
else:
|
||||
result_mock.scalars.return_value.first.return_value = val
|
||||
result_mock.scalar.return_value = val
|
||||
return result_mock
|
||||
|
||||
session.execute = AsyncMock(side_effect=_execute)
|
||||
session.commit = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
session.delete = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
return session
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _session_ctx(session):
|
||||
yield session
|
||||
|
||||
|
||||
# --- Fixtures ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
"""Create a test FastAPI app with automation routes and mocked auth."""
|
||||
app = FastAPI()
|
||||
app.include_router(automation_router)
|
||||
|
||||
def mock_get_user_id():
|
||||
return TEST_USER_ID
|
||||
|
||||
app.dependency_overrides[get_user_id] = mock_get_user_id
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_app):
|
||||
return TestClient(mock_app)
|
||||
|
||||
|
||||
# --- Test: POST /api/v1/automations ---
|
||||
|
||||
|
||||
class TestCreateAutomation:
|
||||
def test_create_success(self, client):
|
||||
"""POST with valid input → 201 with AutomationResponse."""
|
||||
mock_session = _mock_session_with_results([])
|
||||
|
||||
async def fake_refresh(obj):
|
||||
obj.created_at = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
obj.updated_at = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
obj.last_triggered_at = None
|
||||
|
||||
mock_session.refresh = AsyncMock(side_effect=fake_refresh)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.automations.generate_automation_file',
|
||||
return_value='__config__ = {"name": "Test", "triggers": {"cron": {"schedule": "0 9 * * 5"}}}',
|
||||
),
|
||||
patch(
|
||||
'server.routes.automations.extract_config',
|
||||
return_value={
|
||||
'name': 'Test',
|
||||
'triggers': {'cron': {'schedule': '0 9 * * 5'}},
|
||||
},
|
||||
),
|
||||
patch('server.routes.automations.validate_config'),
|
||||
patch('server.routes.automations.file_store') as mock_fs,
|
||||
patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
),
|
||||
):
|
||||
response = client.post(
|
||||
'/api/v1/automations',
|
||||
json={
|
||||
'name': 'Test',
|
||||
'schedule': '0 9 * * 5',
|
||||
'prompt': 'Summarize PRs',
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()
|
||||
assert data['name'] == 'Test'
|
||||
assert data['enabled'] is True
|
||||
assert data['trigger_type'] == 'cron'
|
||||
assert 'id' in data
|
||||
mock_fs.write.assert_called_once()
|
||||
|
||||
def test_create_missing_name(self, client):
|
||||
"""POST with missing name → 422."""
|
||||
response = client.post(
|
||||
'/api/v1/automations',
|
||||
json={'schedule': '0 9 * * 5', 'prompt': 'Test'},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_create_empty_name(self, client):
|
||||
"""POST with empty name → 422."""
|
||||
response = client.post(
|
||||
'/api/v1/automations',
|
||||
json={'name': '', 'schedule': '0 9 * * 5', 'prompt': 'Test'},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_create_missing_prompt(self, client):
|
||||
"""POST with missing prompt → 422."""
|
||||
response = client.post(
|
||||
'/api/v1/automations',
|
||||
json={'name': 'Test', 'schedule': '0 9 * * 5'},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_create_invalid_config_rejected(self, client):
|
||||
"""POST where validate_config raises → 422."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.automations.generate_automation_file',
|
||||
return_value='__config__ = {}',
|
||||
),
|
||||
patch(
|
||||
'server.routes.automations.extract_config',
|
||||
return_value={},
|
||||
),
|
||||
patch(
|
||||
'server.routes.automations.validate_config',
|
||||
side_effect=ValueError('Invalid cron expression'),
|
||||
),
|
||||
):
|
||||
response = client.post(
|
||||
'/api/v1/automations',
|
||||
json={
|
||||
'name': 'Bad Cron',
|
||||
'schedule': 'not-a-cron',
|
||||
'prompt': 'Test',
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
assert 'Invalid automation config' in response.json()['detail']
|
||||
|
||||
|
||||
# --- Test: GET /api/v1/automations/search ---
|
||||
|
||||
|
||||
class TestSearchAutomations:
|
||||
def test_list_returns_user_automations(self, client):
|
||||
"""GET /search → returns only current user's automations."""
|
||||
a1 = _make_automation(name='Auto 1')
|
||||
a2 = _make_automation(name='Auto 2')
|
||||
|
||||
# Session calls: count query → 2, paginated query → [a1, a2]
|
||||
mock_session = _mock_session_with_results([2, [a1, a2]])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = client.get('/api/v1/automations/search')
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data['total'] == 2
|
||||
assert len(data['items']) == 2
|
||||
assert data['items'][0]['name'] == 'Auto 1'
|
||||
|
||||
def test_list_empty(self, client):
|
||||
"""GET /search when no automations → empty list."""
|
||||
mock_session = _mock_session_with_results([0, []])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = client.get('/api/v1/automations/search')
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data['total'] == 0
|
||||
assert data['items'] == []
|
||||
assert data['next_page_id'] is None
|
||||
|
||||
|
||||
# --- Test: GET /api/v1/automations/{id} ---
|
||||
|
||||
|
||||
class TestGetAutomation:
|
||||
def test_get_existing(self, client):
|
||||
"""GET existing automation → 200."""
|
||||
auto = _make_automation(automation_id='auto-123')
|
||||
mock_session = _mock_session_with_results([auto])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = client.get('/api/v1/automations/auto-123')
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()['id'] == 'auto-123'
|
||||
|
||||
def test_get_nonexistent(self, client):
|
||||
"""GET non-existent automation → 404."""
|
||||
mock_session = _mock_session_with_results([None])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = client.get('/api/v1/automations/does-not-exist')
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
# --- Test: PATCH /api/v1/automations/{id} ---
|
||||
|
||||
|
||||
class TestUpdateAutomation:
|
||||
def test_update_name_and_enabled(self, client):
|
||||
"""PATCH with name + enabled → updates fields, returns 200."""
|
||||
auto = _make_automation(automation_id='auto-123')
|
||||
mock_session = _mock_session_with_results([auto])
|
||||
|
||||
async def fake_refresh(obj):
|
||||
obj.name = 'Updated Name'
|
||||
obj.enabled = False
|
||||
obj.id = 'auto-123'
|
||||
obj.trigger_type = 'cron'
|
||||
obj.config = auto.config
|
||||
obj.last_triggered_at = None
|
||||
obj.created_at = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
obj.updated_at = datetime(2026, 1, 2, tzinfo=UTC)
|
||||
|
||||
mock_session.refresh = AsyncMock(side_effect=fake_refresh)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
),
|
||||
patch(
|
||||
'server.routes.automations.generate_automation_file',
|
||||
return_value='__config__ = {}',
|
||||
),
|
||||
patch(
|
||||
'server.routes.automations.extract_config',
|
||||
return_value=auto.config,
|
||||
),
|
||||
patch('server.routes.automations.validate_config'),
|
||||
patch('server.routes.automations.file_store') as mock_fs,
|
||||
):
|
||||
mock_fs.read.return_value = (
|
||||
'conversation.send_message("old prompt")\nconversation.run()'
|
||||
)
|
||||
response = client.patch(
|
||||
'/api/v1/automations/auto-123',
|
||||
json={'name': 'Updated Name', 'enabled': False},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data['name'] == 'Updated Name'
|
||||
assert data['enabled'] is False
|
||||
|
||||
def test_update_nonexistent(self, client):
|
||||
"""PATCH non-existent → 404."""
|
||||
mock_session = _mock_session_with_results([None])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = client.patch(
|
||||
'/api/v1/automations/nope',
|
||||
json={'name': 'X'},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_update_prompt_regenerates_file(self, client):
|
||||
"""PATCH with new prompt → re-generates file and uploads."""
|
||||
auto = _make_automation(automation_id='auto-123')
|
||||
mock_session = _mock_session_with_results([auto])
|
||||
|
||||
async def fake_refresh(obj):
|
||||
obj.id = 'auto-123'
|
||||
obj.name = auto.name
|
||||
obj.enabled = auto.enabled
|
||||
obj.trigger_type = 'cron'
|
||||
obj.config = auto.config
|
||||
obj.last_triggered_at = None
|
||||
obj.created_at = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
obj.updated_at = datetime(2026, 1, 2, tzinfo=UTC)
|
||||
|
||||
mock_session.refresh = AsyncMock(side_effect=fake_refresh)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
),
|
||||
patch(
|
||||
'server.routes.automations.generate_automation_file',
|
||||
return_value='__config__ = {}',
|
||||
) as mock_gen,
|
||||
patch(
|
||||
'server.routes.automations.extract_config',
|
||||
return_value=auto.config,
|
||||
),
|
||||
patch('server.routes.automations.validate_config'),
|
||||
patch('server.routes.automations.file_store') as mock_fs,
|
||||
):
|
||||
response = client.patch(
|
||||
'/api/v1/automations/auto-123',
|
||||
json={'prompt': 'New prompt text'},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_gen.assert_called_once()
|
||||
mock_fs.write.assert_called_once()
|
||||
|
||||
|
||||
# --- Test: DELETE /api/v1/automations/{id} ---
|
||||
|
||||
|
||||
class TestDeleteAutomation:
|
||||
def test_delete_existing(self, client):
|
||||
"""DELETE existing → 204."""
|
||||
auto = _make_automation(automation_id='auto-123')
|
||||
# First execute: select automation, second: delete runs
|
||||
mock_session = _mock_session_with_results([auto, None])
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
),
|
||||
patch('server.routes.automations.file_store'),
|
||||
):
|
||||
response = client.delete('/api/v1/automations/auto-123')
|
||||
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
def test_delete_nonexistent(self, client):
|
||||
"""DELETE non-existent → 404."""
|
||||
mock_session = _mock_session_with_results([None])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = client.delete('/api/v1/automations/nope')
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
# --- Test: POST /api/v1/automations/{id}/run ---
|
||||
|
||||
|
||||
class TestManualTrigger:
|
||||
def test_manual_trigger_success(self, client):
|
||||
"""POST .../run on existing automation → 202."""
|
||||
auto = _make_automation(automation_id='auto-123')
|
||||
mock_session = _mock_session_with_results([auto])
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
),
|
||||
patch(
|
||||
'server.routes.automations.publish_automation_event',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_pub,
|
||||
):
|
||||
response = client.post('/api/v1/automations/auto-123/run')
|
||||
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
data = response.json()
|
||||
assert data['status'] == 'accepted'
|
||||
assert 'dedup_key' in data
|
||||
assert data['dedup_key'].startswith('manual-auto-123-')
|
||||
mock_pub.assert_called_once()
|
||||
|
||||
def test_manual_trigger_nonexistent(self, client):
|
||||
"""POST .../run on non-existent → 404."""
|
||||
mock_session = _mock_session_with_results([None])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = client.post('/api/v1/automations/nope/run')
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
# --- Test: GET /api/v1/automations/{id}/runs ---
|
||||
|
||||
|
||||
class TestListRuns:
|
||||
def test_list_runs_success(self, client):
|
||||
"""GET .../runs → paginated list."""
|
||||
r1 = _make_run(run_id='run-1', automation_id='auto-123')
|
||||
r2 = _make_run(run_id='run-2', automation_id='auto-123')
|
||||
|
||||
# Calls: ownership check → 'auto-123', count → 2, paginated → [r1, r2]
|
||||
mock_session = _mock_session_with_results(['auto-123', 2, [r1, r2]])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = client.get('/api/v1/automations/auto-123/runs')
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data['total'] == 2
|
||||
assert len(data['items']) == 2
|
||||
|
||||
def test_list_runs_automation_not_found(self, client):
|
||||
"""GET .../runs for non-existent automation → 404."""
|
||||
mock_session = _mock_session_with_results([None])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = client.get('/api/v1/automations/nope/runs')
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
# --- Test: GET /api/v1/automations/{id}/runs/{run_id} ---
|
||||
|
||||
|
||||
class TestGetRun:
|
||||
def test_get_run_success(self, client):
|
||||
"""GET single run → 200."""
|
||||
run = _make_run(run_id='run-1', automation_id='auto-123')
|
||||
|
||||
# Calls: ownership check → 'auto-123', select run → run
|
||||
mock_session = _mock_session_with_results(['auto-123', run])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = client.get('/api/v1/automations/auto-123/runs/run-1')
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()['id'] == 'run-1'
|
||||
|
||||
def test_get_run_not_found(self, client):
|
||||
"""GET non-existent run → 404."""
|
||||
mock_session = _mock_session_with_results(['auto-123', None])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = client.get('/api/v1/automations/auto-123/runs/nope')
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_get_run_automation_not_found(self, client):
|
||||
"""GET run for non-existent automation → 404."""
|
||||
mock_session = _mock_session_with_results([None])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = client.get('/api/v1/automations/nope/runs/run-1')
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
# --- Test: User Isolation (security) ---
|
||||
|
||||
|
||||
class TestUserIsolation:
|
||||
"""Verify that user A cannot access, update, or delete user B's automations.
|
||||
|
||||
The routes filter by user_id from the auth dependency, so automations owned by
|
||||
another user should never be returned (the DB query uses WHERE user_id = <caller>).
|
||||
We simulate this by having the mock session return None for cross-user lookups.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def other_user_app(self):
|
||||
"""App configured to authenticate as OTHER_USER_ID."""
|
||||
app = FastAPI()
|
||||
app.include_router(automation_router)
|
||||
|
||||
def mock_get_other_user_id():
|
||||
return OTHER_USER_ID
|
||||
|
||||
app.dependency_overrides[get_user_id] = mock_get_other_user_id
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def other_client(self, other_user_app):
|
||||
return TestClient(other_user_app)
|
||||
|
||||
def test_cannot_get_other_users_automation(self, other_client):
|
||||
"""User B cannot GET user A's automation → 404."""
|
||||
# The query filters by user_id=OTHER_USER_ID, so it won't find user A's row
|
||||
mock_session = _mock_session_with_results([None])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = other_client.get('/api/v1/automations/auto-owned-by-a')
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_cannot_update_other_users_automation(self, other_client):
|
||||
"""User B cannot PATCH user A's automation → 404."""
|
||||
mock_session = _mock_session_with_results([None])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = other_client.patch(
|
||||
'/api/v1/automations/auto-owned-by-a',
|
||||
json={'name': 'Hijacked'},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_cannot_delete_other_users_automation(self, other_client):
|
||||
"""User B cannot DELETE user A's automation → 404."""
|
||||
mock_session = _mock_session_with_results([None])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = other_client.delete('/api/v1/automations/auto-owned-by-a')
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
def test_search_returns_empty_for_other_user(self, other_client):
|
||||
"""User B's search returns empty even if user A has automations."""
|
||||
# count=0, rows=[]
|
||||
mock_session = _mock_session_with_results([0, []])
|
||||
|
||||
with patch(
|
||||
'server.routes.automations.a_session_maker',
|
||||
return_value=_session_ctx(mock_session),
|
||||
):
|
||||
response = other_client.get('/api/v1/automations/search')
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data['total'] == 0
|
||||
assert data['items'] == []
|
||||
@@ -0,0 +1,264 @@
|
||||
"""Integration tests for automation CRUD API using a real in-memory SQLite database.
|
||||
|
||||
These tests exercise actual SQL queries (list, get, create+verify, pagination, delete)
|
||||
rather than mocking the database layer.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from server.routes.automations import automation_router
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from storage.automation import Automation, AutomationRun
|
||||
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
TEST_USER_ID = 'integration-test-user'
|
||||
OTHER_USER_ID = 'other-user'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
engine = create_async_engine('sqlite+aiosqlite://', echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_maker(db_engine):
|
||||
return async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(session_maker):
|
||||
"""FastAPI app wired to a real SQLite database."""
|
||||
app = FastAPI()
|
||||
app.include_router(automation_router)
|
||||
app.dependency_overrides[get_user_id] = lambda: TEST_USER_ID
|
||||
|
||||
@asynccontextmanager
|
||||
async def _session_ctx():
|
||||
async with session_maker() as session:
|
||||
yield session
|
||||
|
||||
with patch('server.routes.automations.a_session_maker', _session_ctx):
|
||||
yield app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(app):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url='http://test') as c:
|
||||
yield c
|
||||
|
||||
|
||||
def _make_automation_obj(
|
||||
user_id: str = TEST_USER_ID,
|
||||
name: str = 'Test Auto',
|
||||
created_at: datetime | None = None,
|
||||
**kwargs,
|
||||
) -> Automation:
|
||||
return Automation(
|
||||
id=kwargs.get('automation_id', uuid.uuid4().hex),
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
enabled=kwargs.get('enabled', True),
|
||||
config=kwargs.get(
|
||||
'config',
|
||||
{
|
||||
'name': name,
|
||||
'triggers': {'cron': {'schedule': '0 9 * * 5', 'timezone': 'UTC'}},
|
||||
'prompt': 'Do something',
|
||||
},
|
||||
),
|
||||
trigger_type='cron',
|
||||
file_store_key=kwargs.get('file_store_key', f'automations/{uuid.uuid4().hex}/automation.py'),
|
||||
created_at=created_at or datetime.now(UTC),
|
||||
updated_at=created_at or datetime.now(UTC),
|
||||
)
|
||||
|
||||
|
||||
# ---------- Test: list (search) returns correct results ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_returns_user_automations(client, session_maker):
|
||||
"""GET /search returns only automations owned by the requesting user."""
|
||||
async with session_maker() as session:
|
||||
a1 = _make_automation_obj(name='Auto A', created_at=datetime(2026, 1, 1, tzinfo=UTC))
|
||||
a2 = _make_automation_obj(name='Auto B', created_at=datetime(2026, 1, 2, tzinfo=UTC))
|
||||
a_other = _make_automation_obj(user_id=OTHER_USER_ID, name='Other User Auto')
|
||||
session.add_all([a1, a2, a_other])
|
||||
await session.commit()
|
||||
|
||||
response = await client.get('/api/v1/automations/search')
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data['total'] == 2
|
||||
assert len(data['items']) == 2
|
||||
names = {item['name'] for item in data['items']}
|
||||
assert names == {'Auto A', 'Auto B'}
|
||||
|
||||
|
||||
# ---------- Test: get returns the right object ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_returns_correct_automation(client, session_maker):
|
||||
"""GET /{id} returns the correct automation by ID."""
|
||||
auto_id = uuid.uuid4().hex
|
||||
async with session_maker() as session:
|
||||
auto = _make_automation_obj(automation_id=auto_id, name='Specific Auto')
|
||||
session.add(auto)
|
||||
await session.commit()
|
||||
|
||||
response = await client.get(f'/api/v1/automations/{auto_id}')
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data['id'] == auto_id
|
||||
assert data['name'] == 'Specific Auto'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_returns_404(client):
|
||||
"""GET /{id} for non-existent automation returns 404."""
|
||||
response = await client.get('/api/v1/automations/does-not-exist')
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ---------- Test: create + verify in DB ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_stores_in_db(client, session_maker):
|
||||
"""POST creates an automation and it's readable from the database."""
|
||||
mock_file_store = MagicMock()
|
||||
config = {
|
||||
'name': 'New Auto',
|
||||
'triggers': {'cron': {'schedule': '0 9 * * 5', 'timezone': 'UTC'}},
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.automations.generate_automation_file',
|
||||
return_value='__config__ = {}',
|
||||
),
|
||||
patch('server.routes.automations.extract_config', return_value=config),
|
||||
patch('server.routes.automations.validate_config'),
|
||||
patch('server.routes.automations.file_store', mock_file_store),
|
||||
):
|
||||
response = await client.post(
|
||||
'/api/v1/automations',
|
||||
json={
|
||||
'name': 'New Auto',
|
||||
'schedule': '0 9 * * 5',
|
||||
'prompt': 'Summarize PRs',
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
created_id = data['id']
|
||||
|
||||
# Verify it's in the DB via the GET endpoint
|
||||
get_response = await client.get(f'/api/v1/automations/{created_id}')
|
||||
assert get_response.status_code == 200
|
||||
assert get_response.json()['name'] == 'New Auto'
|
||||
# Verify prompt is stored in config
|
||||
assert get_response.json()['config'].get('prompt') == 'Summarize PRs'
|
||||
|
||||
|
||||
# ---------- Test: delete actually deletes ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_removes_from_db(client, session_maker):
|
||||
"""DELETE removes the automation from the database."""
|
||||
auto_id = uuid.uuid4().hex
|
||||
async with session_maker() as session:
|
||||
auto = _make_automation_obj(automation_id=auto_id, name='To Delete')
|
||||
session.add(auto)
|
||||
await session.commit()
|
||||
|
||||
mock_file_store = MagicMock()
|
||||
with patch('server.routes.automations.file_store', mock_file_store):
|
||||
response = await client.delete(f'/api/v1/automations/{auto_id}')
|
||||
assert response.status_code == 204
|
||||
|
||||
# Verify it's gone
|
||||
get_response = await client.get(f'/api/v1/automations/{auto_id}')
|
||||
assert get_response.status_code == 404
|
||||
|
||||
|
||||
# ---------- Test: pagination actually works ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pagination_returns_correct_pages(client, session_maker):
|
||||
"""Pagination with limit returns correct page sizes and next_page_id."""
|
||||
base_time = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
async with session_maker() as session:
|
||||
for i in range(5):
|
||||
auto = _make_automation_obj(
|
||||
name=f'Auto {i}',
|
||||
created_at=base_time + timedelta(hours=i),
|
||||
)
|
||||
session.add(auto)
|
||||
await session.commit()
|
||||
|
||||
# First page with limit=2
|
||||
response = await client.get('/api/v1/automations/search?limit=2')
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data['total'] == 5
|
||||
assert len(data['items']) == 2
|
||||
assert data['next_page_id'] is not None
|
||||
|
||||
# Second page using cursor — should return remaining items before cursor
|
||||
next_id = data['next_page_id']
|
||||
response2 = await client.get(f'/api/v1/automations/search?limit=2&page_id={next_id}')
|
||||
assert response2.status_code == 200
|
||||
data2 = response2.json()
|
||||
assert len(data2['items']) == 2
|
||||
|
||||
# Collect all items from both pages and verify no duplicates
|
||||
all_ids = [item['id'] for item in data['items']] + [
|
||||
item['id'] for item in data2['items']
|
||||
]
|
||||
assert len(all_ids) == len(set(all_ids)), 'Pages must not contain duplicate items'
|
||||
|
||||
|
||||
# ---------- Test: user isolation at DB level ----------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_isolation(client, session_maker):
|
||||
"""User A cannot see or access User B's automations via actual DB queries."""
|
||||
auto_id = uuid.uuid4().hex
|
||||
async with session_maker() as session:
|
||||
other_auto = _make_automation_obj(
|
||||
automation_id=auto_id,
|
||||
user_id=OTHER_USER_ID,
|
||||
name='Other User Auto',
|
||||
)
|
||||
session.add(other_auto)
|
||||
await session.commit()
|
||||
|
||||
# Should not be found by TEST_USER_ID
|
||||
response = await client.get(f'/api/v1/automations/{auto_id}')
|
||||
assert response.status_code == 404
|
||||
|
||||
# Should not appear in search
|
||||
search_response = await client.get('/api/v1/automations/search')
|
||||
assert search_response.status_code == 200
|
||||
assert search_response.json()['total'] == 0
|
||||
Reference in New Issue
Block a user