mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
v1 for workflows
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from .db import BaseDBModel, Gallery, Message, Run, RunStatus, Session, Settings, Team
|
||||
from .db import BaseDBModel, Gallery, Message, Run, RunStatus, Session, Settings, Team, WorkflowDB
|
||||
from .types import (
|
||||
EnvironmentVariable,
|
||||
GalleryComponents,
|
||||
@@ -33,4 +33,5 @@ __all__ = [
|
||||
"Settings",
|
||||
"EnvironmentVariable",
|
||||
"Gallery",
|
||||
"WorkflowDB",
|
||||
]
|
||||
|
||||
@@ -20,6 +20,13 @@ from .types import (
|
||||
TeamResult,
|
||||
)
|
||||
|
||||
# Import WorkflowConfig for workflow storage
|
||||
try:
|
||||
from ..workflow.core import WorkflowConfig
|
||||
except ImportError:
|
||||
# Fallback if workflow system is not available
|
||||
WorkflowConfig = dict
|
||||
|
||||
|
||||
class BaseDBModel(SQLModel, table=False):
|
||||
"""
|
||||
@@ -137,6 +144,34 @@ class Settings(BaseDBModel, table=True):
|
||||
config: Union[SettingsConfig, dict] = Field(default_factory=SettingsConfig, sa_column=Column(JSON))
|
||||
|
||||
|
||||
# --- Workflow system database models ---
|
||||
|
||||
|
||||
class WorkflowDB(BaseDBModel, table=True):
|
||||
"""Database model for storing workflows."""
|
||||
|
||||
__table_args__ = {"sqlite_autoincrement": True}
|
||||
|
||||
name: str = "Unnamed Workflow"
|
||||
description: str = ""
|
||||
|
||||
# Store the serialized WorkflowConfig
|
||||
config: Union[ComponentModel, dict] = Field(sa_column=Column(JSON))
|
||||
|
||||
# Optional metadata for organization
|
||||
tags: List[str] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
|
||||
# Workflow status (for future use if needed)
|
||||
is_active: bool = Field(default=True)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_encoders={
|
||||
datetime: lambda v: v.isoformat(),
|
||||
SecretStr: lambda v: v.get_secret_value(),
|
||||
}
|
||||
) # type: ignore[call-arg]
|
||||
|
||||
|
||||
# --- Evaluation system database models ---
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from .auth.middleware import AuthMiddleware
|
||||
from .config import settings
|
||||
from .deps import cleanup_managers, init_auth_manager, init_managers, register_auth_dependencies
|
||||
from .initialization import AppInitializer
|
||||
from .routes import gallery, mcp, runs, sessions, settingsroute, teams, validation, ws
|
||||
from .routes import gallery, mcp, runs, sessions, settingsroute, teams, validation, workflows, ws
|
||||
|
||||
# Initialize application
|
||||
app_file_path = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -155,6 +155,12 @@ api.include_router(
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
api.include_router(
|
||||
workflows.router,
|
||||
tags=["workflows"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
# Version endpoint
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,371 @@
|
||||
# /api/workflows routes
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...datamodel import WorkflowDB
|
||||
from ...workflow.core import Workflow, WorkflowRunner, WorkflowExecution, WorkflowStatus
|
||||
from ..auth.dependencies import get_ws_auth_manager, get_current_user
|
||||
from ..auth.models import User
|
||||
from ..auth.wsauth import WebSocketAuthHandler
|
||||
from ..deps import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class CreateWorkflowRequest(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
config: Dict[str, Any]
|
||||
tags: list[str] = []
|
||||
user_id: str
|
||||
|
||||
|
||||
class UpdateWorkflowRequest(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
config: Dict[str, Any] | None = None
|
||||
tags: list[str] | None = None
|
||||
is_active: bool | None = None
|
||||
|
||||
|
||||
class CreateWorkflowRunRequest(BaseModel):
|
||||
workflow_id: int | None = None
|
||||
workflow_config: Dict[str, Any] | None = None
|
||||
|
||||
|
||||
# ==================== REST API Routes ====================
|
||||
|
||||
@router.get("/api/workflows")
|
||||
async def list_workflows(
|
||||
user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db=Depends(get_db)
|
||||
) -> Dict[str, Any]:
|
||||
"""List all workflows with optional filters"""
|
||||
try:
|
||||
workflows = db.get(WorkflowDB, filters={"user_id": user_id, "is_active": True})
|
||||
return {"status": True, "data": workflows.data or []}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/api/workflows")
|
||||
async def create_workflow(request: CreateWorkflowRequest, db=Depends(get_db)) -> Dict:
|
||||
"""Create a new workflow"""
|
||||
try:
|
||||
workflow = db.upsert(
|
||||
WorkflowDB(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
config=request.config,
|
||||
tags=request.tags,
|
||||
user_id=request.user_id,
|
||||
is_active=True,
|
||||
),
|
||||
return_json=False,
|
||||
)
|
||||
return {"status": workflow.status, "data": {"workflow_id": workflow.data.id}}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/api/workflows/{workflow_id}")
|
||||
async def get_workflow(workflow_id: int, user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""Get workflow details"""
|
||||
try:
|
||||
workflow = db.get(WorkflowDB, filters={"id": workflow_id, "user_id": user_id}, return_json=False)
|
||||
if not workflow.status or not workflow.data:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
return {"status": True, "data": workflow.data[0]}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.put("/api/workflows/{workflow_id}")
|
||||
async def update_workflow(
|
||||
workflow_id: int, request: UpdateWorkflowRequest, user_id: str, db=Depends(get_db)
|
||||
) -> Dict:
|
||||
"""Update workflow"""
|
||||
try:
|
||||
# First check if workflow exists and belongs to user
|
||||
existing = db.get(WorkflowDB, filters={"id": workflow_id, "user_id": user_id}, return_json=False)
|
||||
if not existing.status or not existing.data:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
workflow = existing.data[0]
|
||||
|
||||
# Update only provided fields
|
||||
if request.name is not None:
|
||||
workflow.name = request.name
|
||||
if request.description is not None:
|
||||
workflow.description = request.description
|
||||
if request.config is not None:
|
||||
workflow.config = request.config
|
||||
if request.tags is not None:
|
||||
workflow.tags = request.tags
|
||||
if request.is_active is not None:
|
||||
workflow.is_active = request.is_active
|
||||
|
||||
updated = db.upsert(workflow, return_json=False)
|
||||
return {"status": updated.status, "data": updated.data}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.delete("/api/workflows/{workflow_id}")
|
||||
async def delete_workflow(workflow_id: int, user_id: str, db=Depends(get_db)) -> Dict:
|
||||
"""Delete workflow (soft delete by setting is_active=False)"""
|
||||
try:
|
||||
# Check if workflow exists and belongs to user
|
||||
workflow = db.get(WorkflowDB, filters={"id": workflow_id, "user_id": user_id}, return_json=False)
|
||||
if not workflow.status or not workflow.data:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
# Soft delete by setting is_active=False
|
||||
workflow_obj = workflow.data[0]
|
||||
workflow_obj.is_active = False
|
||||
|
||||
result = db.upsert(workflow_obj)
|
||||
return {"status": result.status, "message": "Workflow deleted successfully"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/api/workflows/run")
|
||||
async def create_workflow_run(request: CreateWorkflowRunRequest, db=Depends(get_db)) -> Dict:
|
||||
"""Create an ephemeral workflow run - returns temporary run_id"""
|
||||
try:
|
||||
# Generate ephemeral run ID
|
||||
run_id = str(uuid.uuid4())
|
||||
|
||||
# If workflow_id provided, fetch the workflow config
|
||||
if request.workflow_id:
|
||||
workflow = db.get(WorkflowDB, filters={"id": request.workflow_id}, return_json=False)
|
||||
if not workflow.status or not workflow.data:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
workflow_config = workflow.data[0].config
|
||||
elif request.workflow_config:
|
||||
# Use provided inline config
|
||||
workflow_config = request.workflow_config
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Either workflow_id or workflow_config must be provided")
|
||||
|
||||
# For now, just return the run_id
|
||||
# The actual workflow execution will happen via WebSocket
|
||||
return {
|
||||
"status": True,
|
||||
"data": {
|
||||
"run_id": run_id,
|
||||
"workflow_config": workflow_config
|
||||
}
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
# ==================== WebSocket Routes ====================
|
||||
|
||||
# Global tracking for ephemeral workflow runs
|
||||
active_workflow_runs: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
class WorkflowWebSocketManager:
|
||||
"""Manages WebSocket connections for workflow execution"""
|
||||
|
||||
def __init__(self):
|
||||
self.connections: Dict[str, WebSocket] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, run_id: str) -> bool:
|
||||
"""Connect a WebSocket for a workflow run"""
|
||||
try:
|
||||
await websocket.accept()
|
||||
self.connections[run_id] = websocket
|
||||
logger.info(f"Workflow WebSocket connected for run {run_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect workflow WebSocket: {e}")
|
||||
return False
|
||||
|
||||
async def disconnect(self, run_id: str):
|
||||
"""Disconnect and cleanup"""
|
||||
if run_id in self.connections:
|
||||
del self.connections[run_id]
|
||||
if run_id in active_workflow_runs:
|
||||
del active_workflow_runs[run_id]
|
||||
logger.info(f"Workflow WebSocket disconnected for run {run_id}")
|
||||
|
||||
async def send_message(self, run_id: str, message: Dict[str, Any]):
|
||||
"""Send message to connected WebSocket"""
|
||||
if run_id in self.connections:
|
||||
try:
|
||||
await self.connections[run_id].send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send message to run {run_id}: {e}")
|
||||
|
||||
async def execute_workflow(self, run_id: str, workflow_config: Dict[str, Any], initial_input: Any = None):
|
||||
"""Execute a workflow and stream progress"""
|
||||
try:
|
||||
# Create workflow from config
|
||||
workflow = Workflow.load_component(workflow_config)
|
||||
|
||||
# Create and run workflow runner
|
||||
runner = WorkflowRunner()
|
||||
|
||||
# Stream workflow execution events
|
||||
try:
|
||||
async for event in runner.run_stream(workflow, initial_input):
|
||||
# Send the event directly as JSON (with run_id added)
|
||||
event_data = event.model_dump()
|
||||
event_data["run_id"] = run_id # Add run_id for WebSocket context
|
||||
await self.send_message(run_id, event_data)
|
||||
|
||||
except Exception as workflow_error:
|
||||
# Send error message for unexpected errors
|
||||
await self.send_message(run_id, {
|
||||
"type": "workflow_error",
|
||||
"run_id": run_id,
|
||||
"error": str(workflow_error),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Workflow execution error for run {run_id}: {e}")
|
||||
await self.send_message(run_id, {
|
||||
"type": "workflow_error",
|
||||
"run_id": run_id,
|
||||
"error": f"Failed to execute workflow: {str(e)}",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
|
||||
# Global workflow WebSocket manager
|
||||
workflow_manager = WorkflowWebSocketManager()
|
||||
|
||||
|
||||
@router.websocket("/api/workflow/ws/{run_id}")
|
||||
async def workflow_websocket(
|
||||
websocket: WebSocket,
|
||||
run_id: str,
|
||||
auth_manager=Depends(get_ws_auth_manager),
|
||||
):
|
||||
"""WebSocket endpoint for workflow execution"""
|
||||
|
||||
try:
|
||||
# Connect websocket
|
||||
connected = await workflow_manager.connect(websocket, run_id)
|
||||
if not connected:
|
||||
return
|
||||
|
||||
# Handle authentication if enabled
|
||||
if auth_manager is not None:
|
||||
ws_auth = WebSocketAuthHandler(auth_manager)
|
||||
success, user = await ws_auth.authenticate(websocket)
|
||||
if not success:
|
||||
logger.warning(f"Authentication failed for workflow WebSocket run {run_id}")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"error": "Authentication failed",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
})
|
||||
return
|
||||
|
||||
logger.info(f"Workflow WebSocket connection established for run {run_id}")
|
||||
|
||||
# Store run info
|
||||
active_workflow_runs[run_id] = {
|
||||
"created_at": datetime.now(),
|
||||
"status": "connected"
|
||||
}
|
||||
|
||||
raw_message = None # Initialize to avoid unbound variable issue
|
||||
while True:
|
||||
try:
|
||||
raw_message = await websocket.receive_text()
|
||||
message = json.loads(raw_message)
|
||||
|
||||
if message.get("type") == "start":
|
||||
# Handle start message
|
||||
logger.info(f"Received workflow start request for run {run_id}")
|
||||
|
||||
workflow_config = message.get("workflow_config")
|
||||
initial_input = message.get("input")
|
||||
|
||||
if not workflow_config:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"error": "workflow_config is required",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
})
|
||||
continue
|
||||
|
||||
# Start workflow execution in background
|
||||
asyncio.create_task(
|
||||
workflow_manager.execute_workflow(run_id, workflow_config, initial_input)
|
||||
)
|
||||
|
||||
elif message.get("type") == "stop":
|
||||
logger.info(f"Received workflow stop request for run {run_id}")
|
||||
# TODO: Implement workflow cancellation
|
||||
await websocket.send_json({
|
||||
"type": "workflow_stopped",
|
||||
"run_id": run_id,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
})
|
||||
break
|
||||
|
||||
elif message.get("type") == "ping":
|
||||
await websocket.send_json({
|
||||
"type": "pong",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown message type: {message.get('type')}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid JSON received: {raw_message or 'None'}")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"error": "Invalid message format",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"Workflow WebSocket disconnected for run {run_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Workflow WebSocket error: {str(e)}")
|
||||
finally:
|
||||
await workflow_manager.disconnect(run_id)
|
||||
|
||||
|
||||
@router.get("/api/workflow/ws/status/{run_id}")
|
||||
async def get_workflow_run_status(run_id: str):
|
||||
"""Get status of an active workflow run"""
|
||||
if run_id not in active_workflow_runs:
|
||||
return {"status": False, "message": "Run not found"}
|
||||
|
||||
run_info = active_workflow_runs[run_id]
|
||||
return {
|
||||
"status": True,
|
||||
"data": {
|
||||
"run_id": run_id,
|
||||
"created_at": run_info["created_at"].isoformat(),
|
||||
"status": run_info["status"]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Workflow system for autogenstudio.
|
||||
"""
|
||||
|
||||
from .core import Workflow, WorkflowRunner, WorkflowMetadata, StepMetadata
|
||||
from .steps import FunctionStep, EchoStep
|
||||
|
||||
__all__ = ["Workflow", "FunctionStep", "EchoStep", "WorkflowRunner", "WorkflowMetadata", "StepMetadata"]
|
||||
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Core workflow engine components.
|
||||
"""
|
||||
|
||||
from ._workflow import Workflow, BaseWorkflow, WorkflowConfig
|
||||
from ._runner import WorkflowRunner
|
||||
from ._models import (
|
||||
InputType, OutputType, StepStatus, WorkflowStatus,
|
||||
Edge, EdgeCondition, StepExecution, WorkflowExecution,
|
||||
StepMetadata, WorkflowMetadata, Context, WorkflowValidationResult
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Workflow classes
|
||||
"Workflow", "BaseWorkflow", "WorkflowConfig",
|
||||
# Runner
|
||||
"WorkflowRunner",
|
||||
# Models and types
|
||||
"InputType", "OutputType", "StepStatus", "WorkflowStatus",
|
||||
"Edge", "EdgeCondition", "StepExecution", "WorkflowExecution",
|
||||
"StepMetadata", "WorkflowMetadata", "Context", "WorkflowValidationResult"
|
||||
]
|
||||
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
Core data models for the workflow system.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, TypeVar
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from enum import Enum
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# Type variables for generic step inputs/outputs
|
||||
InputType = TypeVar("InputType", bound=BaseModel)
|
||||
OutputType = TypeVar("OutputType", bound=BaseModel)
|
||||
|
||||
|
||||
class StepStatus(str, Enum):
|
||||
"""Status of a step in workflow execution."""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
class WorkflowStatus(str, Enum):
|
||||
"""Status of workflow execution."""
|
||||
CREATED = "created"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class EdgeCondition(BaseModel):
|
||||
"""Defines conditions for workflow edges."""
|
||||
type: str = Field(default="always", description="Type of condition: always, output_based, state_based")
|
||||
expression: Optional[str] = Field(default=None, description="Python expression to evaluate")
|
||||
field: Optional[str] = Field(default=None, description="Field to check in output or state")
|
||||
value: Optional[Any] = Field(default=None, description="Expected value")
|
||||
operator: Optional[str] = Field(default=None, description="Comparison operator: ==, !=, >, <, in, etc.")
|
||||
|
||||
|
||||
class Edge(BaseModel):
|
||||
"""Represents a connection between workflow steps."""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
from_step: str = Field(description="Source step ID")
|
||||
to_step: str = Field(description="Target step ID")
|
||||
condition: EdgeCondition = Field(default_factory=lambda: EdgeCondition())
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class StepExecution(BaseModel):
|
||||
"""Tracks execution details of a step."""
|
||||
step_id: str
|
||||
status: StepStatus = StepStatus.PENDING
|
||||
start_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None
|
||||
input_data: Optional[Dict[str, Any]] = None
|
||||
output_data: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
retry_count: int = 0
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class WorkflowExecution(BaseModel):
|
||||
"""Tracks execution of an entire workflow."""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
workflow_id: str
|
||||
status: WorkflowStatus = WorkflowStatus.CREATED
|
||||
start_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None
|
||||
state: Dict[str, Any] = Field(default_factory=dict)
|
||||
step_executions: Dict[str, StepExecution] = Field(default_factory=dict)
|
||||
error: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class StepMetadata(BaseModel):
|
||||
"""Metadata for workflow steps."""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
max_retries: int = 0
|
||||
timeout_seconds: Optional[int] = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class WorkflowMetadata(BaseModel):
|
||||
"""Metadata for workflows."""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
version: str = "1.0.0"
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
author: Optional[str] = None
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class Context(BaseModel):
|
||||
"""Simple typed context for workflow steps."""
|
||||
|
||||
state: Dict[str, Any] = Field(default_factory=dict, description="Shared mutable workflow state")
|
||||
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
|
||||
@classmethod
|
||||
def from_state_ref(cls, state_dict: Dict[str, Any]) -> "Context":
|
||||
"""Create Context with direct reference to state dict (no copy)."""
|
||||
# Create instance normally but then replace the state reference
|
||||
instance = cls(state={}) # Initialize with empty dict
|
||||
instance.__dict__['state'] = state_dict # Directly set the dict reference
|
||||
return instance
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a value from workflow state."""
|
||||
return self.state.get(key, default)
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
"""Set a value in workflow state."""
|
||||
self.state[key] = value
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for backward compatibility."""
|
||||
return {
|
||||
'workflow_state': self.state,
|
||||
**self.state # Also include state values directly
|
||||
}
|
||||
|
||||
|
||||
class WorkflowValidationResult(BaseModel):
|
||||
"""Result of workflow validation."""
|
||||
is_valid: bool
|
||||
errors: List[str] = Field(default_factory=list)
|
||||
warnings: List[str] = Field(default_factory=list)
|
||||
has_cycles: bool = False
|
||||
unreachable_steps: List[str] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
# Workflow Event Models for Streaming
|
||||
class WorkflowEventType(str, Enum):
|
||||
"""Types of workflow events."""
|
||||
WORKFLOW_STARTED = "workflow_started"
|
||||
WORKFLOW_COMPLETED = "workflow_completed"
|
||||
WORKFLOW_FAILED = "workflow_failed"
|
||||
STEP_STARTED = "step_started"
|
||||
STEP_COMPLETED = "step_completed"
|
||||
STEP_FAILED = "step_failed"
|
||||
EDGE_ACTIVATED = "edge_activated"
|
||||
|
||||
|
||||
class WorkflowEvent(BaseModel):
|
||||
"""Base class for workflow events."""
|
||||
event_type: WorkflowEventType
|
||||
timestamp: datetime
|
||||
workflow_id: str
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class WorkflowStartedEvent(WorkflowEvent):
|
||||
"""Workflow execution started."""
|
||||
event_type: WorkflowEventType = WorkflowEventType.WORKFLOW_STARTED
|
||||
initial_input: Dict[str, Any]
|
||||
|
||||
|
||||
class WorkflowCompletedEvent(WorkflowEvent):
|
||||
"""Workflow execution completed successfully."""
|
||||
event_type: WorkflowEventType = WorkflowEventType.WORKFLOW_COMPLETED
|
||||
execution: WorkflowExecution
|
||||
|
||||
|
||||
class WorkflowFailedEvent(WorkflowEvent):
|
||||
"""Workflow execution failed."""
|
||||
event_type: WorkflowEventType = WorkflowEventType.WORKFLOW_FAILED
|
||||
error: str
|
||||
execution: Optional[WorkflowExecution] = None
|
||||
|
||||
|
||||
class StepStartedEvent(WorkflowEvent):
|
||||
"""Step execution started."""
|
||||
event_type: WorkflowEventType = WorkflowEventType.STEP_STARTED
|
||||
step_id: str
|
||||
input_data: Dict[str, Any]
|
||||
|
||||
|
||||
class StepCompletedEvent(WorkflowEvent):
|
||||
"""Step execution completed successfully."""
|
||||
event_type: WorkflowEventType = WorkflowEventType.STEP_COMPLETED
|
||||
step_id: str
|
||||
output_data: Dict[str, Any]
|
||||
duration_seconds: float
|
||||
|
||||
|
||||
class StepFailedEvent(WorkflowEvent):
|
||||
"""Step execution failed."""
|
||||
event_type: WorkflowEventType = WorkflowEventType.STEP_FAILED
|
||||
step_id: str
|
||||
error: str
|
||||
duration_seconds: float
|
||||
|
||||
|
||||
class EdgeActivatedEvent(WorkflowEvent):
|
||||
"""Edge between steps activated (data flowing)."""
|
||||
event_type: WorkflowEventType = WorkflowEventType.EDGE_ACTIVATED
|
||||
from_step: str
|
||||
to_step: str
|
||||
data: Dict[str, Any]
|
||||
@@ -0,0 +1,464 @@
|
||||
"""
|
||||
Workflow runner implementation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, AsyncGenerator
|
||||
from datetime import datetime
|
||||
|
||||
from ._models import (
|
||||
WorkflowExecution, StepExecution, WorkflowStatus, StepStatus,
|
||||
WorkflowEvent, WorkflowStartedEvent, WorkflowCompletedEvent, WorkflowFailedEvent,
|
||||
StepStartedEvent, StepCompletedEvent, StepFailedEvent, EdgeActivatedEvent
|
||||
)
|
||||
from ._workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunner:
|
||||
"""Executes workflows with support for parallel execution."""
|
||||
|
||||
def __init__(self, max_concurrent_steps: int = 5):
|
||||
"""Initialize the runner.
|
||||
|
||||
Args:
|
||||
max_concurrent_steps: Maximum number of steps to run concurrently
|
||||
"""
|
||||
self.max_concurrent_steps = max_concurrent_steps
|
||||
self._execution_semaphore = asyncio.Semaphore(max_concurrent_steps)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
initial_input: Optional[Dict[str, Any]] = None
|
||||
) -> WorkflowExecution:
|
||||
"""Run a complete workflow and return the final result.
|
||||
|
||||
This is a convenience method that consumes the stream and returns
|
||||
only the final WorkflowExecution result.
|
||||
|
||||
Args:
|
||||
workflow: Workflow to execute
|
||||
initial_input: Initial input data for the start step
|
||||
|
||||
Returns:
|
||||
Final workflow execution result
|
||||
"""
|
||||
final_execution = None
|
||||
async for event in self.run_stream(workflow, initial_input):
|
||||
if event.event_type == "workflow_completed":
|
||||
final_execution = getattr(event, 'execution', None)
|
||||
elif event.event_type == "workflow_failed":
|
||||
execution = getattr(event, 'execution', None)
|
||||
if execution:
|
||||
final_execution = execution
|
||||
# Re-raise the error for backward compatibility
|
||||
error = getattr(event, 'error', 'Unknown workflow error')
|
||||
raise RuntimeError(error)
|
||||
|
||||
if final_execution is None:
|
||||
raise RuntimeError("Workflow completed but no final execution received")
|
||||
|
||||
return final_execution
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
initial_input: Optional[Dict[str, Any]] = None
|
||||
) -> AsyncGenerator[WorkflowEvent, None]:
|
||||
"""Run a workflow and yield real-time events.
|
||||
|
||||
Args:
|
||||
workflow: Workflow to execute
|
||||
initial_input: Initial input data for the start step
|
||||
|
||||
Yields:
|
||||
WorkflowEvent: Real-time workflow events
|
||||
|
||||
Raises:
|
||||
Exception: If workflow validation fails or execution encounters errors
|
||||
"""
|
||||
logger.info(f"Starting workflow execution: {workflow.id}")
|
||||
|
||||
# Emit workflow started event
|
||||
yield WorkflowStartedEvent(
|
||||
timestamp=datetime.now(),
|
||||
workflow_id=workflow.id,
|
||||
initial_input=initial_input or {}
|
||||
)
|
||||
|
||||
# Validate workflow
|
||||
validation = workflow.validate_workflow()
|
||||
if not validation.is_valid:
|
||||
error_msg = f"Workflow validation failed: {validation.errors}"
|
||||
logger.error(error_msg)
|
||||
yield WorkflowFailedEvent(
|
||||
timestamp=datetime.now(),
|
||||
workflow_id=workflow.id,
|
||||
error=error_msg
|
||||
)
|
||||
return
|
||||
|
||||
# Validate initial input matches start step's input type
|
||||
if initial_input and workflow.start_step_id:
|
||||
start_step = workflow.steps.get(workflow.start_step_id)
|
||||
if start_step:
|
||||
try:
|
||||
# Try to validate initial input against start step's input type
|
||||
start_step.input_type(**initial_input)
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
f"Initial input validation failed: Input does not match start step '{workflow.start_step_id}' "
|
||||
f"input type {start_step.input_type.__name__}: {str(e)}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
yield WorkflowFailedEvent(
|
||||
timestamp=datetime.now(),
|
||||
workflow_id=workflow.id,
|
||||
error=error_msg
|
||||
)
|
||||
return
|
||||
|
||||
# Create execution record
|
||||
execution = WorkflowExecution(
|
||||
workflow_id=workflow.id,
|
||||
status=WorkflowStatus.RUNNING,
|
||||
start_time=datetime.now(),
|
||||
state=workflow.initial_state.copy()
|
||||
)
|
||||
|
||||
try:
|
||||
# Add initial input to state if provided
|
||||
if initial_input:
|
||||
execution.state.update(initial_input)
|
||||
|
||||
# Execute the workflow with streaming events
|
||||
async for event in self._execute_workflow_stream(workflow, execution, initial_input or {}):
|
||||
yield event
|
||||
|
||||
# Check final status and emit completion event
|
||||
if all(
|
||||
step_exec.status == StepStatus.COMPLETED
|
||||
for step_exec in execution.step_executions.values()
|
||||
):
|
||||
execution.status = WorkflowStatus.COMPLETED
|
||||
execution.end_time = datetime.now()
|
||||
logger.info(f"Workflow {workflow.id} completed successfully")
|
||||
|
||||
yield WorkflowCompletedEvent(
|
||||
timestamp=datetime.now(),
|
||||
workflow_id=workflow.id,
|
||||
execution=execution
|
||||
)
|
||||
else:
|
||||
execution.status = WorkflowStatus.FAILED
|
||||
execution.end_time = datetime.now()
|
||||
error_msg = f"Workflow {workflow.id} failed"
|
||||
logger.error(error_msg)
|
||||
|
||||
yield WorkflowFailedEvent(
|
||||
timestamp=datetime.now(),
|
||||
workflow_id=workflow.id,
|
||||
error=error_msg,
|
||||
execution=execution
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution.status = WorkflowStatus.FAILED
|
||||
execution.error = str(e)
|
||||
execution.end_time = datetime.now()
|
||||
logger.error(f"Workflow {workflow.id} failed with error: {e}")
|
||||
|
||||
yield WorkflowFailedEvent(
|
||||
timestamp=datetime.now(),
|
||||
workflow_id=workflow.id,
|
||||
error=str(e),
|
||||
execution=execution
|
||||
)
|
||||
|
||||
async def _execute_workflow_stream(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
execution: WorkflowExecution,
|
||||
initial_input: Dict[str, Any]
|
||||
) -> AsyncGenerator[WorkflowEvent, None]:
|
||||
"""Execute the workflow steps and yield events.
|
||||
|
||||
Args:
|
||||
workflow: Workflow to execute
|
||||
execution: Execution context
|
||||
initial_input: Initial input data
|
||||
|
||||
Yields:
|
||||
WorkflowEvent: Step execution events
|
||||
"""
|
||||
completed_steps = set()
|
||||
running_tasks = {}
|
||||
|
||||
while len(completed_steps) < len(workflow.steps):
|
||||
# Get steps ready to run
|
||||
ready_steps = workflow.get_ready_steps(execution)
|
||||
ready_steps = [s for s in ready_steps if s not in completed_steps and s not in running_tasks]
|
||||
|
||||
if not ready_steps and not running_tasks:
|
||||
# No ready steps and nothing running - check if we're stuck
|
||||
remaining_steps = set(workflow.steps.keys()) - completed_steps
|
||||
if remaining_steps:
|
||||
error_msg = f"Workflow stuck: remaining steps {remaining_steps} cannot be executed"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
break
|
||||
|
||||
# Start new tasks for ready steps
|
||||
for step_id in ready_steps:
|
||||
if len(running_tasks) >= self.max_concurrent_steps:
|
||||
break
|
||||
|
||||
step = workflow.steps[step_id]
|
||||
input_data = self._prepare_step_input(step_id, workflow, execution, initial_input)
|
||||
|
||||
# Create step execution record
|
||||
step_execution = StepExecution(
|
||||
step_id=step_id,
|
||||
status=StepStatus.RUNNING,
|
||||
start_time=datetime.now(),
|
||||
input_data=input_data
|
||||
)
|
||||
execution.step_executions[step_id] = step_execution
|
||||
|
||||
# Emit step started event
|
||||
yield StepStartedEvent(
|
||||
timestamp=datetime.now(),
|
||||
workflow_id=workflow.id,
|
||||
step_id=step_id,
|
||||
input_data=input_data
|
||||
)
|
||||
|
||||
# Start the step task
|
||||
task = asyncio.create_task(self._run_step_with_semaphore(step, input_data, execution.state))
|
||||
running_tasks[step_id] = task
|
||||
|
||||
logger.info(f"Started step {step_id} in workflow {workflow.id}")
|
||||
|
||||
# Wait for at least one task to complete
|
||||
if running_tasks:
|
||||
done, pending = await asyncio.wait(
|
||||
running_tasks.values(),
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
# Process completed tasks
|
||||
for task in done:
|
||||
step_id = None
|
||||
for sid, t in running_tasks.items():
|
||||
if t == task:
|
||||
step_id = sid
|
||||
break
|
||||
|
||||
if step_id:
|
||||
step_execution = execution.step_executions[step_id]
|
||||
|
||||
try:
|
||||
result = await task
|
||||
step_execution.status = StepStatus.COMPLETED
|
||||
step_execution.output_data = result
|
||||
step_execution.end_time = datetime.now()
|
||||
|
||||
# Calculate duration
|
||||
duration = 0.0
|
||||
if step_execution.end_time and step_execution.start_time:
|
||||
duration = (step_execution.end_time - step_execution.start_time).total_seconds()
|
||||
|
||||
# Update workflow state with step output
|
||||
execution.state[f"{step_id}_output"] = result
|
||||
|
||||
completed_steps.add(step_id)
|
||||
logger.info(f"Step {step_id} completed successfully")
|
||||
|
||||
# Emit step completed event
|
||||
yield StepCompletedEvent(
|
||||
timestamp=datetime.now(),
|
||||
workflow_id=workflow.id,
|
||||
step_id=step_id,
|
||||
output_data=result,
|
||||
duration_seconds=duration
|
||||
)
|
||||
|
||||
# Emit edge activation events for next steps
|
||||
for edge in workflow.edges:
|
||||
if edge.from_step == step_id:
|
||||
yield EdgeActivatedEvent(
|
||||
timestamp=datetime.now(),
|
||||
workflow_id=workflow.id,
|
||||
from_step=step_id,
|
||||
to_step=edge.to_step,
|
||||
data=result
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
step_execution.status = StepStatus.FAILED
|
||||
step_execution.error = str(e)
|
||||
step_execution.end_time = datetime.now()
|
||||
|
||||
# Calculate duration
|
||||
duration = (step_execution.end_time - step_execution.start_time).total_seconds()
|
||||
|
||||
logger.error(f"Step {step_id} failed: {e}")
|
||||
|
||||
# Emit step failed event
|
||||
yield StepFailedEvent(
|
||||
timestamp=datetime.now(),
|
||||
workflow_id=workflow.id,
|
||||
step_id=step_id,
|
||||
error=str(e),
|
||||
duration_seconds=duration
|
||||
)
|
||||
|
||||
# For now, fail the entire workflow if any step fails
|
||||
# In the future, we could add error handling strategies
|
||||
raise
|
||||
|
||||
finally:
|
||||
del running_tasks[step_id]
|
||||
|
||||
# Check if we've reached an end step
|
||||
if any(step_id in completed_steps for step_id in workflow.end_step_ids):
|
||||
logger.info(f"Reached end step in workflow {workflow.id}")
|
||||
break
|
||||
|
||||
async def _run_step_with_semaphore(
|
||||
self,
|
||||
step,
|
||||
input_data: Dict[str, Any],
|
||||
workflow_state: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Run a step with concurrency control.
|
||||
|
||||
Args:
|
||||
step: Step to execute
|
||||
input_data: Input data for the step
|
||||
workflow_state: Current workflow state
|
||||
|
||||
Returns:
|
||||
Step output data
|
||||
"""
|
||||
async with self._execution_semaphore:
|
||||
from ._models import Context
|
||||
|
||||
# Create typed context that directly references workflow_state
|
||||
# This ensures modifications are persistent across steps
|
||||
typed_context = Context.from_state_ref(workflow_state)
|
||||
|
||||
# Convert to dict for step.run() compatibility, but context modifications
|
||||
# will still affect the original workflow_state since it's the same dict reference
|
||||
context = typed_context.to_dict()
|
||||
return await step.run(input_data, context)
|
||||
|
||||
def _prepare_step_input(
|
||||
self,
|
||||
step_id: str,
|
||||
workflow: Workflow,
|
||||
execution: WorkflowExecution,
|
||||
initial_input: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Prepare input data for a step using direct type forwarding.
|
||||
|
||||
Args:
|
||||
step_id: Step to prepare input for
|
||||
workflow: Workflow being executed
|
||||
execution: Current execution state
|
||||
initial_input: Initial workflow input
|
||||
|
||||
Returns:
|
||||
Input data for the step
|
||||
"""
|
||||
# Start with initial input for the start step
|
||||
if step_id == workflow.start_step_id:
|
||||
return initial_input.copy()
|
||||
|
||||
# For other steps, use direct output forwarding from dependencies
|
||||
dependencies = workflow.get_step_dependencies(step_id)
|
||||
|
||||
if not dependencies:
|
||||
# No dependencies, use initial input
|
||||
return initial_input.copy()
|
||||
|
||||
# For sequential workflows: use the most recent dependency's output directly
|
||||
# For parallel/fan-in: this logic would need to be more sophisticated
|
||||
latest_dependency = dependencies[-1] # Most recent dependency
|
||||
dep_execution = execution.step_executions.get(latest_dependency)
|
||||
|
||||
if dep_execution and dep_execution.output_data:
|
||||
# Direct forwarding: previous step's output becomes this step's input
|
||||
return dep_execution.output_data.copy()
|
||||
else:
|
||||
# Fallback to initial input if dependency output not available
|
||||
logger.warning(f"No output available from dependency {latest_dependency} for step {step_id}, using initial input")
|
||||
return initial_input.copy()
|
||||
|
||||
async def run_step(
|
||||
self,
|
||||
step,
|
||||
input_data: Dict[str, Any],
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Run a single step independently.
|
||||
|
||||
Args:
|
||||
step: Step to execute
|
||||
input_data: Input data
|
||||
context: Additional context
|
||||
|
||||
Returns:
|
||||
Step output data
|
||||
"""
|
||||
context = context or {}
|
||||
return await step.run(input_data, context)
|
||||
|
||||
def get_execution_status(self, execution: WorkflowExecution) -> Dict[str, Any]:
|
||||
"""Get detailed status of a workflow execution.
|
||||
|
||||
Args:
|
||||
execution: Workflow execution to analyze
|
||||
|
||||
Returns:
|
||||
Status information
|
||||
"""
|
||||
total_steps = len(execution.step_executions)
|
||||
completed_steps = sum(
|
||||
1 for step_exec in execution.step_executions.values()
|
||||
if step_exec.status == StepStatus.COMPLETED
|
||||
)
|
||||
failed_steps = sum(
|
||||
1 for step_exec in execution.step_executions.values()
|
||||
if step_exec.status == StepStatus.FAILED
|
||||
)
|
||||
running_steps = sum(
|
||||
1 for step_exec in execution.step_executions.values()
|
||||
if step_exec.status == StepStatus.RUNNING
|
||||
)
|
||||
|
||||
duration = None
|
||||
if execution.start_time and execution.end_time:
|
||||
duration = (execution.end_time - execution.start_time).total_seconds()
|
||||
|
||||
return {
|
||||
"execution_id": execution.id,
|
||||
"workflow_id": execution.workflow_id,
|
||||
"status": execution.status.value,
|
||||
"progress": {
|
||||
"total_steps": total_steps,
|
||||
"completed_steps": completed_steps,
|
||||
"failed_steps": failed_steps,
|
||||
"running_steps": running_steps,
|
||||
"percentage": (completed_steps / total_steps * 100) if total_steps > 0 else 0
|
||||
},
|
||||
"timing": {
|
||||
"start_time": execution.start_time,
|
||||
"end_time": execution.end_time,
|
||||
"duration_seconds": duration
|
||||
},
|
||||
"error": execution.error
|
||||
}
|
||||
@@ -0,0 +1,474 @@
|
||||
"""
|
||||
Workflow implementation for the process system.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, TypeVar
|
||||
from typing_extensions import Self
|
||||
from pydantic import BaseModel, Field
|
||||
import uuid
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from autogen_core import Component, ComponentModel, ComponentBase
|
||||
|
||||
from ._models import (
|
||||
Edge, WorkflowMetadata, WorkflowValidationResult,
|
||||
WorkflowStatus, StepExecution, WorkflowExecution
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..steps._step import BaseStep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type variable for return type chaining
|
||||
WorkflowT = TypeVar("WorkflowT", bound="BaseWorkflow")
|
||||
|
||||
|
||||
class WorkflowConfig(BaseModel):
|
||||
"""Configuration for workflow serialization."""
|
||||
metadata: WorkflowMetadata
|
||||
steps: List[ComponentModel] = Field(default_factory=list, description="Serialized step component models")
|
||||
edges: List[Edge] = Field(default_factory=list)
|
||||
initial_state: Dict[str, Any] = Field(default_factory=dict)
|
||||
start_step_id: Optional[str] = None
|
||||
end_step_ids: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BaseWorkflow(ComponentBase[BaseModel]):
|
||||
"""Base class for workflows with core logic."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metadata: WorkflowMetadata,
|
||||
initial_state: Optional[Dict[str, Any]] = None,
|
||||
workflow_id: Optional[str] = None
|
||||
):
|
||||
"""Initialize the workflow.
|
||||
|
||||
Args:
|
||||
metadata: Workflow metadata
|
||||
initial_state: Initial workflow state
|
||||
workflow_id: Optional workflow ID
|
||||
"""
|
||||
self.id = workflow_id or str(uuid.uuid4())
|
||||
self.metadata = metadata
|
||||
self.steps: Dict[str, "BaseStep"] = {}
|
||||
self.edges: List[Edge] = []
|
||||
self.initial_state = initial_state or {}
|
||||
self.start_step_id: Optional[str] = None
|
||||
self.end_step_ids: List[str] = []
|
||||
|
||||
def add_step(self, step: "BaseStep") -> Self:
|
||||
"""Add a step to the workflow.
|
||||
|
||||
Args:
|
||||
step: Step to add
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
self.steps[step.step_id] = step
|
||||
logger.debug(f"Added step {step.step_id} to workflow {self.id}")
|
||||
return self
|
||||
|
||||
def add_edge(self, from_step: str, to_step: str, condition: Optional[Dict[str, Any]] = None) -> Self:
|
||||
"""Add an edge between steps.
|
||||
|
||||
Args:
|
||||
from_step: Source step ID
|
||||
to_step: Target step ID
|
||||
condition: Optional condition for the edge
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
from ._models import EdgeCondition
|
||||
|
||||
edge_condition = EdgeCondition(**condition) if condition else EdgeCondition()
|
||||
edge = Edge(from_step=from_step, to_step=to_step, condition=edge_condition)
|
||||
self.edges.append(edge)
|
||||
logger.debug(f"Added edge {from_step} -> {to_step} to workflow {self.id}")
|
||||
return self
|
||||
|
||||
def set_start_step(self, step_id: str) -> Self:
|
||||
"""Set the starting step for the workflow.
|
||||
|
||||
Args:
|
||||
step_id: ID of the step to start with
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
if step_id not in self.steps:
|
||||
raise ValueError(f"Step {step_id} not found in workflow")
|
||||
self.start_step_id = step_id
|
||||
logger.debug(f"Set start step to {step_id} for workflow {self.id}")
|
||||
return self
|
||||
|
||||
def add_end_step(self, step_id: str) -> Self:
|
||||
"""Add an end step to the workflow.
|
||||
|
||||
Args:
|
||||
step_id: ID of the step that can end the workflow
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
if step_id not in self.steps:
|
||||
raise ValueError(f"Step {step_id} not found in workflow")
|
||||
if step_id not in self.end_step_ids:
|
||||
self.end_step_ids.append(step_id)
|
||||
logger.debug(f"Added end step {step_id} to workflow {self.id}")
|
||||
return self
|
||||
|
||||
def get_step_dependencies(self, step_id: str) -> List[str]:
|
||||
"""Get all steps that must complete before this step can run.
|
||||
|
||||
Args:
|
||||
step_id: Step to get dependencies for
|
||||
|
||||
Returns:
|
||||
List of step IDs that this step depends on
|
||||
"""
|
||||
return [edge.from_step for edge in self.edges if edge.to_step == step_id]
|
||||
|
||||
def get_step_dependents(self, step_id: str) -> List[str]:
|
||||
"""Get all steps that depend on this step.
|
||||
|
||||
Args:
|
||||
step_id: Step to get dependents for
|
||||
|
||||
Returns:
|
||||
List of step IDs that depend on this step
|
||||
"""
|
||||
return [edge.to_step for edge in self.edges if edge.from_step == step_id]
|
||||
|
||||
def get_ready_steps(self, execution: WorkflowExecution) -> List[str]:
|
||||
"""Get steps that are ready to run (all dependencies completed).
|
||||
|
||||
Args:
|
||||
execution: Current workflow execution state
|
||||
|
||||
Returns:
|
||||
List of step IDs ready to run
|
||||
"""
|
||||
ready_steps = []
|
||||
|
||||
for step_id in self.steps:
|
||||
step_exec = execution.step_executions.get(step_id)
|
||||
|
||||
# Skip if already running, completed, or failed
|
||||
if step_exec and step_exec.status.value in ["running", "completed", "failed"]:
|
||||
continue
|
||||
|
||||
# Check if all dependencies are completed
|
||||
dependencies = self.get_step_dependencies(step_id)
|
||||
if not dependencies and not step_exec:
|
||||
# No dependencies and not started - ready if it's the start step
|
||||
if step_id == self.start_step_id:
|
||||
ready_steps.append(step_id)
|
||||
elif dependencies:
|
||||
# Check if all dependencies are completed
|
||||
all_deps_complete = True
|
||||
for dep_id in dependencies:
|
||||
dep_exec = execution.step_executions.get(dep_id)
|
||||
if not dep_exec or dep_exec.status.value != "completed":
|
||||
all_deps_complete = False
|
||||
break
|
||||
|
||||
if all_deps_complete:
|
||||
# Also check edge conditions
|
||||
for edge in self.edges:
|
||||
if edge.to_step == step_id:
|
||||
if self._evaluate_edge_condition(edge, execution):
|
||||
ready_steps.append(step_id)
|
||||
break
|
||||
|
||||
return ready_steps
|
||||
|
||||
def _evaluate_edge_condition(self, edge: Edge, execution: WorkflowExecution) -> bool:
|
||||
"""Evaluate if an edge condition is met.
|
||||
|
||||
Args:
|
||||
edge: Edge to evaluate
|
||||
execution: Current execution state
|
||||
|
||||
Returns:
|
||||
True if condition is met
|
||||
"""
|
||||
condition = edge.condition
|
||||
|
||||
if condition.type == "always":
|
||||
return True
|
||||
|
||||
if condition.type == "output_based":
|
||||
from_step_exec = execution.step_executions.get(edge.from_step)
|
||||
if not from_step_exec or not from_step_exec.output_data:
|
||||
return False
|
||||
|
||||
# Simple field-based condition evaluation
|
||||
if condition.field and condition.operator and condition.value is not None:
|
||||
field_value = from_step_exec.output_data.get(condition.field)
|
||||
return self._compare_values(field_value, condition.operator, condition.value)
|
||||
|
||||
if condition.type == "state_based":
|
||||
if condition.field and condition.operator and condition.value is not None:
|
||||
field_value = execution.state.get(condition.field)
|
||||
return self._compare_values(field_value, condition.operator, condition.value)
|
||||
|
||||
# For expression-based conditions, we'd eval the expression here
|
||||
# For now, default to True for unsupported conditions
|
||||
return True
|
||||
|
||||
def _compare_values(self, left: Any, operator: str, right: Any) -> bool:
|
||||
"""Compare two values using the given operator.
|
||||
|
||||
Args:
|
||||
left: Left operand
|
||||
operator: Comparison operator
|
||||
right: Right operand
|
||||
|
||||
Returns:
|
||||
Comparison result
|
||||
"""
|
||||
try:
|
||||
if operator == "==":
|
||||
return left == right
|
||||
elif operator == "!=":
|
||||
return left != right
|
||||
elif operator == ">":
|
||||
return left > right
|
||||
elif operator == "<":
|
||||
return left < right
|
||||
elif operator == ">=":
|
||||
return left >= right
|
||||
elif operator == "<=":
|
||||
return left <= right
|
||||
elif operator == "in":
|
||||
return left in right
|
||||
elif operator == "not_in":
|
||||
return left not in right
|
||||
else:
|
||||
logger.warning(f"Unknown operator: {operator}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error comparing values: {e}")
|
||||
return False
|
||||
|
||||
def validate_workflow(self) -> WorkflowValidationResult:
|
||||
"""Validate the workflow structure.
|
||||
|
||||
Returns:
|
||||
Validation result with errors and warnings
|
||||
"""
|
||||
result = WorkflowValidationResult(is_valid=True)
|
||||
|
||||
# Check if workflow has steps
|
||||
if not self.steps:
|
||||
result.errors.append("Workflow has no steps")
|
||||
result.is_valid = False
|
||||
|
||||
# Check if start step is set and exists
|
||||
if not self.start_step_id:
|
||||
result.errors.append("No start step specified")
|
||||
result.is_valid = False
|
||||
elif self.start_step_id not in self.steps:
|
||||
result.errors.append(f"Start step {self.start_step_id} not found in workflow")
|
||||
result.is_valid = False
|
||||
|
||||
# Check if end steps exist
|
||||
if not self.end_step_ids:
|
||||
result.warnings.append("No end steps specified - workflow may run indefinitely")
|
||||
else:
|
||||
for end_step_id in self.end_step_ids:
|
||||
if end_step_id not in self.steps:
|
||||
result.errors.append(f"End step {end_step_id} not found in workflow")
|
||||
result.is_valid = False
|
||||
|
||||
# Check if all edge references exist
|
||||
for edge in self.edges:
|
||||
if edge.from_step not in self.steps:
|
||||
result.errors.append(f"Edge references non-existent step: {edge.from_step}")
|
||||
result.is_valid = False
|
||||
if edge.to_step not in self.steps:
|
||||
result.errors.append(f"Edge references non-existent step: {edge.to_step}")
|
||||
result.is_valid = False
|
||||
|
||||
# Check for cycles using DFS
|
||||
result.has_cycles, cycle_info = self._detect_cycles()
|
||||
if result.has_cycles:
|
||||
result.errors.append(f"Workflow contains cycles: {cycle_info}")
|
||||
result.is_valid = False
|
||||
|
||||
# Check for unreachable steps
|
||||
result.unreachable_steps = self._find_unreachable_steps()
|
||||
if result.unreachable_steps:
|
||||
result.warnings.append(f"Unreachable steps found: {result.unreachable_steps}")
|
||||
|
||||
# Check for type compatibility between connected steps
|
||||
for edge in self.edges:
|
||||
if edge.from_step in self.steps and edge.to_step in self.steps:
|
||||
from_step = self.steps[edge.from_step]
|
||||
to_step = self.steps[edge.to_step]
|
||||
|
||||
# Check type compatibility using schema-based comparison
|
||||
# This is more robust than direct type comparison, especially for dynamically created types
|
||||
types_compatible = False
|
||||
|
||||
if hasattr(from_step.output_type, 'model_json_schema') and hasattr(to_step.input_type, 'model_json_schema'):
|
||||
from_schema = from_step.output_type.model_json_schema()
|
||||
to_schema = to_step.input_type.model_json_schema()
|
||||
|
||||
# Consider types compatible if they have the same name and schema
|
||||
if (from_step.output_type.__name__ == to_step.input_type.__name__ and
|
||||
from_schema == to_schema):
|
||||
types_compatible = True
|
||||
logger.debug(f"Types compatible by schema: {edge.from_step} -> {edge.to_step}")
|
||||
else:
|
||||
logger.debug(f"Schema mismatch for edge {edge.from_step} -> {edge.to_step}")
|
||||
logger.debug(f" Output schema: {from_schema}")
|
||||
logger.debug(f" Input schema: {to_schema}")
|
||||
else:
|
||||
# Fallback to direct type comparison for non-Pydantic types
|
||||
types_compatible = from_step.output_type == to_step.input_type
|
||||
logger.debug(f"Using direct type comparison: {types_compatible}")
|
||||
|
||||
if not types_compatible:
|
||||
error_msg = (
|
||||
f"Type mismatch: Step '{edge.from_step}' outputs {from_step.output_type.__name__} "
|
||||
f"but step '{edge.to_step}' expects {to_step.input_type.__name__}"
|
||||
)
|
||||
result.errors.append(error_msg)
|
||||
result.is_valid = False
|
||||
|
||||
return result
|
||||
|
||||
def _detect_cycles(self) -> tuple[bool, Optional[str]]:
|
||||
"""Detect cycles in the workflow graph.
|
||||
|
||||
Returns:
|
||||
Tuple of (has_cycles, cycle_description)
|
||||
"""
|
||||
if not self.start_step_id:
|
||||
return False, None
|
||||
|
||||
visited = set()
|
||||
rec_stack = set()
|
||||
|
||||
def dfs(step_id: str, path: List[str]) -> tuple[bool, Optional[str]]:
|
||||
if step_id in rec_stack:
|
||||
cycle_start = path.index(step_id)
|
||||
cycle = " -> ".join(path[cycle_start:] + [step_id])
|
||||
return True, cycle
|
||||
|
||||
if step_id in visited:
|
||||
return False, None
|
||||
|
||||
visited.add(step_id)
|
||||
rec_stack.add(step_id)
|
||||
|
||||
# Get all steps this step leads to
|
||||
next_steps = self.get_step_dependents(step_id)
|
||||
for next_step in next_steps:
|
||||
has_cycle, cycle_info = dfs(next_step, path + [step_id])
|
||||
if has_cycle:
|
||||
return True, cycle_info
|
||||
|
||||
rec_stack.remove(step_id)
|
||||
return False, None
|
||||
|
||||
return dfs(self.start_step_id, [])
|
||||
|
||||
def _find_unreachable_steps(self) -> List[str]:
|
||||
"""Find steps that cannot be reached from the start step.
|
||||
|
||||
Returns:
|
||||
List of unreachable step IDs
|
||||
"""
|
||||
if not self.start_step_id:
|
||||
return list(self.steps.keys())
|
||||
|
||||
reachable = set()
|
||||
to_visit = [self.start_step_id]
|
||||
|
||||
while to_visit:
|
||||
current = to_visit.pop()
|
||||
if current in reachable:
|
||||
continue
|
||||
|
||||
reachable.add(current)
|
||||
next_steps = self.get_step_dependents(current)
|
||||
to_visit.extend(next_steps)
|
||||
|
||||
return [step_id for step_id in self.steps if step_id not in reachable]
|
||||
|
||||
def get_execution_plan(self) -> Dict[str, Any]:
|
||||
"""Get a visual representation of the workflow execution plan.
|
||||
|
||||
Returns:
|
||||
Dictionary with workflow structure information
|
||||
"""
|
||||
return {
|
||||
"workflow_id": self.id,
|
||||
"metadata": self.metadata.model_dump(),
|
||||
"steps": {
|
||||
step_id: step.get_schema()
|
||||
for step_id, step in self.steps.items()
|
||||
},
|
||||
"edges": [edge.model_dump() for edge in self.edges],
|
||||
"start_step": self.start_step_id,
|
||||
"end_steps": self.end_step_ids,
|
||||
"validation": self.validate_workflow().model_dump()
|
||||
}
|
||||
|
||||
|
||||
class Workflow(BaseWorkflow, Component[WorkflowConfig]):
|
||||
"""Concrete workflow implementation with component serialization support."""
|
||||
|
||||
component_config_schema = WorkflowConfig
|
||||
component_type = "workflow"
|
||||
component_provider_override = "autogenstudio.workflow.core.Workflow"
|
||||
|
||||
def _to_config(self) -> WorkflowConfig:
|
||||
"""Convert workflow to configuration for serialization."""
|
||||
step_configs = [step.dump_component() for step in self.steps.values()]
|
||||
|
||||
return WorkflowConfig(
|
||||
metadata=self.metadata,
|
||||
steps=step_configs,
|
||||
edges=self.edges,
|
||||
initial_state=self.initial_state,
|
||||
start_step_id=self.start_step_id,
|
||||
end_step_ids=self.end_step_ids
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: WorkflowConfig) -> "Workflow":
|
||||
"""Create workflow from configuration.
|
||||
|
||||
Args:
|
||||
config: Workflow configuration
|
||||
"""
|
||||
from ..steps._step import BaseStep
|
||||
|
||||
workflow = cls(
|
||||
metadata=config.metadata,
|
||||
initial_state=config.initial_state
|
||||
)
|
||||
|
||||
# Deserialize and add steps
|
||||
for step_model in config.steps:
|
||||
step = BaseStep.load_component(step_model)
|
||||
workflow.add_step(step)
|
||||
|
||||
# Add edges
|
||||
for edge in config.edges:
|
||||
workflow.edges.append(edge)
|
||||
|
||||
# Set start and end steps
|
||||
workflow.start_step_id = config.start_step_id
|
||||
workflow.end_step_ids = config.end_step_ids
|
||||
|
||||
return workflow
|
||||
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
Fan-out/Fan-in workflow example: broadcast -> (double, square, add_ten) -> sum
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogenstudio.workflow import Workflow, WorkflowRunner, WorkflowMetadata, StepMetadata
|
||||
from autogenstudio.workflow.steps import FunctionStep
|
||||
from autogenstudio.workflow.core._models import Context
|
||||
|
||||
|
||||
# Define data models
|
||||
class NumberInput(BaseModel):
|
||||
value: int
|
||||
|
||||
|
||||
class NumberOutput(BaseModel):
|
||||
result: int
|
||||
|
||||
|
||||
class SumOutput(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
# Define step functions
|
||||
async def broadcast_input(input_data: NumberInput, context: Context) -> NumberOutput:
|
||||
"""Pass through the input to multiple downstream steps."""
|
||||
print(f"Broadcast step - input_data: {input_data}, context state keys: {list(context.state.keys())}")
|
||||
result = input_data.value
|
||||
|
||||
print(f"Broadcast step - result: {result} (will be forwarded to parallel steps)")
|
||||
return NumberOutput(result=result)
|
||||
|
||||
|
||||
async def double_number(input_data: NumberOutput, context: Context) -> NumberOutput:
|
||||
"""Double a number."""
|
||||
print(f"Double step - input_data: {input_data}, context state keys: {list(context.state.keys())}")
|
||||
await asyncio.sleep(0.1) # Simulate some work
|
||||
|
||||
# Use the direct input from broadcast step
|
||||
value = input_data.result
|
||||
|
||||
result = value * 2
|
||||
print(f"Double step - using value: {value}, result: {result}")
|
||||
|
||||
# Store result in context for fan-in step
|
||||
context.set('double_result', result)
|
||||
|
||||
return NumberOutput(result=result)
|
||||
|
||||
|
||||
async def square_number(input_data: NumberOutput, context: Context) -> NumberOutput:
|
||||
"""Square a number."""
|
||||
print(f"Square step - input_data: {input_data}, context state keys: {list(context.state.keys())}")
|
||||
await asyncio.sleep(0.1) # Simulate some work
|
||||
|
||||
# Use the direct input from broadcast step
|
||||
value = input_data.result
|
||||
|
||||
result = value ** 2
|
||||
print(f"Square step - using value: {value}, result: {result}")
|
||||
|
||||
# Store result in context for fan-in step
|
||||
context.set('square_result', result)
|
||||
|
||||
return NumberOutput(result=result)
|
||||
|
||||
|
||||
async def add_ten(input_data: NumberOutput, context: Context) -> NumberOutput:
|
||||
"""Add 10 to a number."""
|
||||
print(f"Add_ten step - input_data: {input_data}, context state keys: {list(context.state.keys())}")
|
||||
await asyncio.sleep(0.1) # Simulate some work
|
||||
|
||||
# Use the direct input from broadcast step
|
||||
value = input_data.result
|
||||
|
||||
result = value + 10
|
||||
print(f"Add_ten step - using value: {value}, result: {result}")
|
||||
|
||||
# Store result in context for fan-in step
|
||||
context.set('add_ten_result', result)
|
||||
|
||||
return NumberOutput(result=result)
|
||||
|
||||
|
||||
async def sum_results(input_data: NumberOutput, context: Context) -> SumOutput:
|
||||
"""Sum multiple numbers from parallel steps using shared context."""
|
||||
print(f"Sum step - input_data: {input_data}, context state keys: {list(context.state.keys())}")
|
||||
|
||||
# Collect results from shared context (cleaner than automatic output storage)
|
||||
double_result = context.get('double_result', 0)
|
||||
square_result = context.get('square_result', 0)
|
||||
add_ten_result = context.get('add_ten_result', 0)
|
||||
|
||||
results = [double_result, square_result, add_ten_result]
|
||||
total = sum(results)
|
||||
|
||||
print(f"Sum step - collected from context: double={double_result}, square={square_result}, add_ten={add_ten_result}")
|
||||
print(f"Sum step - total: {total}")
|
||||
|
||||
# Store final summary in context
|
||||
context.set('fan_in_summary', {
|
||||
'inputs': {'double': double_result, 'square': square_result, 'add_ten': add_ten_result},
|
||||
'total': total
|
||||
})
|
||||
|
||||
return SumOutput(total=total)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run the fan-out/fan-in workflow example."""
|
||||
|
||||
print("=== Fan-out/Fan-in Workflow Example ===")
|
||||
print("Expected: 5 -> broadcast(5) -> parallel[double(10), square(25), add_ten(15)] -> sum(50)")
|
||||
|
||||
# Create steps
|
||||
broadcast_step = FunctionStep(
|
||||
step_id="broadcast",
|
||||
metadata=StepMetadata(name="Broadcast Input"),
|
||||
input_type=NumberInput,
|
||||
output_type=NumberOutput,
|
||||
func=broadcast_input
|
||||
)
|
||||
|
||||
double_step = FunctionStep(
|
||||
step_id="double",
|
||||
metadata=StepMetadata(name="Double Number"),
|
||||
input_type=NumberOutput, # Takes NumberOutput from broadcast
|
||||
output_type=NumberOutput,
|
||||
func=double_number
|
||||
)
|
||||
|
||||
square_step = FunctionStep(
|
||||
step_id="square",
|
||||
metadata=StepMetadata(name="Square Number"),
|
||||
input_type=NumberOutput, # Takes NumberOutput from broadcast
|
||||
output_type=NumberOutput,
|
||||
func=square_number
|
||||
)
|
||||
|
||||
add_ten_step = FunctionStep(
|
||||
step_id="add_ten",
|
||||
metadata=StepMetadata(name="Add Ten"),
|
||||
input_type=NumberOutput, # Takes NumberOutput from broadcast
|
||||
output_type=NumberOutput,
|
||||
func=add_ten
|
||||
)
|
||||
|
||||
sum_step = FunctionStep(
|
||||
step_id="sum",
|
||||
metadata=StepMetadata(name="Sum Results"),
|
||||
input_type=NumberOutput, # Takes NumberOutput from parallel steps
|
||||
output_type=SumOutput,
|
||||
func=sum_results
|
||||
)
|
||||
|
||||
# Create workflow
|
||||
workflow = Workflow(
|
||||
metadata=WorkflowMetadata(name="Fan-out Fan-in Example")
|
||||
)
|
||||
|
||||
# Add steps
|
||||
workflow.add_step(broadcast_step)
|
||||
workflow.add_step(double_step)
|
||||
workflow.add_step(square_step)
|
||||
workflow.add_step(add_ten_step)
|
||||
workflow.add_step(sum_step)
|
||||
|
||||
# Set up the fan-out pattern: broadcast -> all three parallel operations
|
||||
workflow.set_start_step("broadcast")
|
||||
workflow.add_edge("broadcast", "double")
|
||||
workflow.add_edge("broadcast", "square")
|
||||
workflow.add_edge("broadcast", "add_ten")
|
||||
|
||||
# Set up the fan-in pattern: all three operations -> sum
|
||||
workflow.add_edge("double", "sum")
|
||||
workflow.add_edge("square", "sum")
|
||||
workflow.add_edge("add_ten", "sum")
|
||||
|
||||
# Set end step
|
||||
workflow.add_end_step("sum")
|
||||
|
||||
# Run workflow
|
||||
runner = WorkflowRunner(max_concurrent_steps=3)
|
||||
initial_input = {"value": 5}
|
||||
|
||||
print(f"\nRunning workflow with input: {initial_input}")
|
||||
execution = await runner.run(workflow, initial_input)
|
||||
|
||||
# Print results
|
||||
print("\n=== Results ===")
|
||||
for step_id, step_exec in execution.step_executions.items():
|
||||
print(f"{step_id}: {step_exec.output_data}")
|
||||
|
||||
print(f"\nFinal result: {execution.step_executions['sum'].output_data}")
|
||||
print(f"Expected: {{'total': 50}}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
Simple parallel workflow example: parallel execution of independent steps
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogenstudio.workflow import Workflow, WorkflowRunner, WorkflowMetadata, StepMetadata
|
||||
from autogenstudio.workflow.steps import FunctionStep
|
||||
|
||||
|
||||
# Define data models
|
||||
class NumberInput(BaseModel):
|
||||
value: int
|
||||
|
||||
|
||||
class NumberOutput(BaseModel):
|
||||
result: int
|
||||
|
||||
|
||||
# Define step functions
|
||||
async def double_number(input_data: NumberInput, context: Dict[str, Any]) -> NumberOutput:
|
||||
"""Double a number."""
|
||||
print(f"Double step - input_data: {input_data}, context keys: {list(context.keys())}")
|
||||
await asyncio.sleep(0.2) # Simulate some work
|
||||
result = input_data.value * 2
|
||||
print(f"Double step - result: {result}")
|
||||
return NumberOutput(result=result)
|
||||
|
||||
|
||||
async def square_number(input_data: NumberInput, context: Dict[str, Any]) -> NumberOutput:
|
||||
"""Square a number."""
|
||||
print(f"Square step - input_data: {input_data}, context keys: {list(context.keys())}")
|
||||
await asyncio.sleep(0.1) # Simulate some work
|
||||
result = input_data.value ** 2
|
||||
print(f"Square step - result: {result}")
|
||||
return NumberOutput(result=result)
|
||||
|
||||
|
||||
async def add_ten(input_data: NumberInput, context: Dict[str, Any]) -> NumberOutput:
|
||||
"""Add 10 to a number."""
|
||||
print(f"Add_ten step - input_data: {input_data}, context keys: {list(context.keys())}")
|
||||
await asyncio.sleep(0.15) # Simulate some work
|
||||
result = input_data.value + 10
|
||||
print(f"Add_ten step - result: {result}")
|
||||
return NumberOutput(result=result)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run the parallel workflow example."""
|
||||
|
||||
print("=== Simple Parallel Workflow Example ===")
|
||||
print("Expected: All steps run in parallel with the same input (7)")
|
||||
print(" - double: 7 * 2 = 14")
|
||||
print(" - square: 7 * 7 = 49")
|
||||
print(" - add_ten: 7 + 10 = 17")
|
||||
|
||||
# Create steps - all independent, no dependencies
|
||||
double_step = FunctionStep(
|
||||
step_id="double",
|
||||
metadata=StepMetadata(name="Double Number"),
|
||||
input_type=NumberInput,
|
||||
output_type=NumberOutput,
|
||||
func=double_number
|
||||
)
|
||||
|
||||
square_step = FunctionStep(
|
||||
step_id="square",
|
||||
metadata=StepMetadata(name="Square Number"),
|
||||
input_type=NumberInput,
|
||||
output_type=NumberOutput,
|
||||
func=square_number
|
||||
)
|
||||
|
||||
add_ten_step = FunctionStep(
|
||||
step_id="add_ten",
|
||||
metadata=StepMetadata(name="Add Ten"),
|
||||
input_type=NumberInput,
|
||||
output_type=NumberOutput,
|
||||
func=add_ten
|
||||
)
|
||||
|
||||
# Create workflow - each step can be a start step since they're independent
|
||||
workflow = Workflow(
|
||||
metadata=WorkflowMetadata(name="Parallel Example")
|
||||
)
|
||||
|
||||
workflow.add_step(double_step)
|
||||
workflow.add_step(square_step)
|
||||
workflow.add_step(add_ten_step)
|
||||
|
||||
# No edges - all steps are independent and can start simultaneously
|
||||
# We'll set one as start step but run all with the same input
|
||||
workflow.set_start_step("double")
|
||||
workflow.add_end_step("double")
|
||||
workflow.add_end_step("square")
|
||||
workflow.add_end_step("add_ten")
|
||||
|
||||
# Run workflow
|
||||
runner = WorkflowRunner(max_concurrent_steps=3)
|
||||
initial_input = {"value": 7}
|
||||
|
||||
print(f"\nRunning workflow with input: {initial_input}")
|
||||
|
||||
# For true parallel execution, we'd need to start all steps at once
|
||||
# This is a limitation of the current workflow design - it expects a single start step
|
||||
# Let's run each step individually to demonstrate parallel capability
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# Create individual tasks
|
||||
tasks = []
|
||||
for step in [double_step, square_step, add_ten_step]:
|
||||
task = asyncio.create_task(step.run(initial_input, {}))
|
||||
tasks.append((step.step_id, task))
|
||||
|
||||
# Wait for all to complete
|
||||
results = {}
|
||||
for step_id, task in tasks:
|
||||
result = await task
|
||||
results[step_id] = result
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Print results
|
||||
print("\n=== Results ===")
|
||||
for step_id, result in results.items():
|
||||
print(f"{step_id}: {result}")
|
||||
|
||||
print(f"\nTotal execution time: {end_time - start_time:.3f} seconds")
|
||||
print("(Should be ~0.2s if truly parallel, not 0.45s if sequential)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Simple sequential workflow example: double -> square -> add_ten
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogenstudio.workflow import Workflow, WorkflowRunner, WorkflowMetadata, StepMetadata
|
||||
from autogenstudio.workflow.steps import FunctionStep
|
||||
from autogenstudio.workflow.core._models import Context
|
||||
|
||||
|
||||
# Define data models
|
||||
class NumberInput(BaseModel):
|
||||
value: int
|
||||
|
||||
|
||||
class NumberOutput(BaseModel):
|
||||
result: int
|
||||
|
||||
|
||||
# Define step functions
|
||||
async def double_number(input_data: NumberInput, context: Context) -> NumberOutput:
|
||||
"""Double a number and track operation in shared context."""
|
||||
print(f"Double step - input_data: {input_data}, context state keys: {list(context.state.keys())}")
|
||||
await asyncio.sleep(0.1) # Simulate some work
|
||||
|
||||
# Track the operation in shared context
|
||||
context.set('operations_performed', ['double'])
|
||||
context.set('original_input', input_data.value)
|
||||
|
||||
value = input_data.value
|
||||
result = value * 2
|
||||
print(f"Double step - using value: {value}, result: {result}")
|
||||
print(f"Double step - stored in context: operations={context.get('operations_performed')}")
|
||||
|
||||
return NumberOutput(result=result)
|
||||
|
||||
|
||||
async def square_number(input_data: NumberOutput, context: Context) -> NumberOutput:
|
||||
"""Square a number and update shared state tracking."""
|
||||
print(f"Square step - input_data: {input_data}, context state keys: {list(context.state.keys())}")
|
||||
await asyncio.sleep(0.1) # Simulate some work
|
||||
|
||||
# Read from shared context and update operation tracking
|
||||
operations = context.get('operations_performed', [])
|
||||
original_input = context.get('original_input', 'unknown')
|
||||
operations.append('square')
|
||||
context.set('operations_performed', operations)
|
||||
context.set('intermediate_results', {'after_double': input_data.result})
|
||||
|
||||
value = input_data.result
|
||||
result = value ** 2
|
||||
|
||||
print(f"Square step - using value: {value}, result: {result}")
|
||||
print(f"Square step - operations so far: {operations}, original input was: {original_input}")
|
||||
|
||||
return NumberOutput(result=result)
|
||||
|
||||
|
||||
async def add_ten(input_data: NumberOutput, context: Context) -> NumberOutput:
|
||||
"""Add 10 to a number and finalize shared state."""
|
||||
print(f"Add_ten step - input_data: {input_data}, context state keys: {list(context.state.keys())}")
|
||||
await asyncio.sleep(0.1) # Simulate some work
|
||||
|
||||
# Read complete operation history from shared context
|
||||
operations = context.get('operations_performed', [])
|
||||
original_input = context.get('original_input', 'unknown')
|
||||
intermediate_results = context.get('intermediate_results', {})
|
||||
|
||||
operations.append('add_ten')
|
||||
context.set('operations_performed', operations)
|
||||
|
||||
value = input_data.result
|
||||
result = value + 10
|
||||
|
||||
# Store final summary in context
|
||||
context.set('workflow_summary', {
|
||||
'original_input': original_input,
|
||||
'operations': operations,
|
||||
'intermediate_results': intermediate_results,
|
||||
'final_result': result
|
||||
})
|
||||
|
||||
print(f"Add_ten step - using value: {value}, result: {result}")
|
||||
print(f"Add_ten step - complete workflow: {original_input} → {operations} → {result}")
|
||||
|
||||
return NumberOutput(result=result)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run the sequential workflow example."""
|
||||
|
||||
print("=== Simple Sequential Workflow Example ===")
|
||||
print("Expected: 3 -> double(6) -> square(36) -> add_ten(46)")
|
||||
print("Now with type-safe direct forwarding AND shared state tracking!")
|
||||
print("")
|
||||
|
||||
# Create steps
|
||||
double_step = FunctionStep(
|
||||
step_id="double",
|
||||
metadata=StepMetadata(name="Double Number"),
|
||||
input_type=NumberInput,
|
||||
output_type=NumberOutput,
|
||||
func=double_number
|
||||
)
|
||||
|
||||
square_step = FunctionStep(
|
||||
step_id="square",
|
||||
metadata=StepMetadata(name="Square Number"),
|
||||
input_type=NumberOutput, # Now takes NumberOutput from previous step
|
||||
output_type=NumberOutput,
|
||||
func=square_number
|
||||
)
|
||||
|
||||
add_ten_step = FunctionStep(
|
||||
step_id="add_ten",
|
||||
metadata=StepMetadata(name="Add Ten"),
|
||||
input_type=NumberOutput, # Now takes NumberOutput from previous step
|
||||
output_type=NumberOutput,
|
||||
func=add_ten
|
||||
)
|
||||
|
||||
# Create workflow
|
||||
workflow = Workflow(
|
||||
metadata=WorkflowMetadata(name="Sequential Example")
|
||||
)
|
||||
|
||||
workflow.add_step(double_step)
|
||||
workflow.add_step(square_step)
|
||||
workflow.add_step(add_ten_step)
|
||||
|
||||
# Create sequence
|
||||
workflow.add_edge("double", "square")
|
||||
workflow.add_edge("square", "add_ten")
|
||||
|
||||
workflow.set_start_step("double")
|
||||
workflow.add_end_step("add_ten")
|
||||
|
||||
# Run workflow
|
||||
runner = WorkflowRunner()
|
||||
initial_input = {"value": 3}
|
||||
|
||||
print(f"\nRunning workflow with input: {initial_input}")
|
||||
execution = await runner.run(workflow, initial_input)
|
||||
|
||||
# Print results
|
||||
print("\n=== Results ===")
|
||||
for step_id, step_exec in execution.step_executions.items():
|
||||
print(f"{step_id}: {step_exec.output_data}")
|
||||
|
||||
print(f"\nFinal result: {execution.step_executions['add_ten'].output_data}")
|
||||
print(f"Expected: {{'result': 46}}")
|
||||
|
||||
# Show shared workflow state
|
||||
print("\n=== Shared Workflow State ===")
|
||||
workflow_summary = execution.state.get('workflow_summary', {})
|
||||
if workflow_summary:
|
||||
print(f"Original input: {workflow_summary.get('original_input')}")
|
||||
print(f"Operations performed: {' → '.join(workflow_summary.get('operations', []))}")
|
||||
print(f"Intermediate results: {workflow_summary.get('intermediate_results')}")
|
||||
print(f"Final result: {workflow_summary.get('final_result')}")
|
||||
else:
|
||||
print("No workflow summary found in shared state")
|
||||
|
||||
print(f"\nAll shared state keys: {list(execution.state.keys())}")
|
||||
|
||||
print("\n=== Workflow Serialization Test ===")
|
||||
|
||||
# Test serialization
|
||||
print("1. Serializing workflow...")
|
||||
dumped_config = workflow.dump_component()
|
||||
print(f" Serialized config type: {type(dumped_config)}")
|
||||
print(f" Config provider: {dumped_config.provider}")
|
||||
print(f" Config version: {dumped_config.version}")
|
||||
|
||||
|
||||
|
||||
# Save workflow to json file for UI integration
|
||||
print("4. Saving workflow JSON for UI...")
|
||||
with open("simple_sequential_workflow.json", "w") as f:
|
||||
f.write(dumped_config.model_dump_json(indent=2))
|
||||
print(" ✅ Saved to simple_sequential_workflow.json")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Step implementations for the workflow system.
|
||||
"""
|
||||
|
||||
from ._step import BaseStep, BaseStepConfig, FunctionStep, EchoStep
|
||||
|
||||
__all__ = ["BaseStep", "BaseStepConfig", "FunctionStep", "EchoStep"]
|
||||
@@ -0,0 +1,520 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Generic, Optional, Type, Callable, Tuple
|
||||
from pydantic import BaseModel, create_model
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from autogen_core import Component, ComponentBase
|
||||
|
||||
from ..core._models import InputType, OutputType, StepMetadata, StepStatus, Context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseStepConfig(BaseModel):
|
||||
"""Base configuration that all step configs must inherit from.
|
||||
|
||||
Ensures UI compatibility by requiring type schema information.
|
||||
"""
|
||||
step_id: str
|
||||
metadata: StepMetadata
|
||||
input_type_name: str
|
||||
output_type_name: str
|
||||
input_schema: Dict[str, Any]
|
||||
output_schema: Dict[str, Any]
|
||||
|
||||
|
||||
class BaseStep(ComponentBase[BaseStepConfig], Generic[InputType, OutputType]):
|
||||
"""Base class for all workflow steps with automatic type serialization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_id: str,
|
||||
metadata: StepMetadata,
|
||||
input_type: Type[InputType],
|
||||
output_type: Type[OutputType]
|
||||
):
|
||||
"""Initialize the step.
|
||||
|
||||
Args:
|
||||
step_id: Unique identifier for this step
|
||||
metadata: Step metadata including name, description, etc.
|
||||
input_type: Pydantic model class for input validation
|
||||
output_type: Pydantic model class for output validation
|
||||
"""
|
||||
self.step_id = step_id
|
||||
self.metadata = metadata
|
||||
self.input_type = input_type
|
||||
self.output_type = output_type
|
||||
self._status = StepStatus.PENDING
|
||||
self._start_time: Optional[datetime] = None
|
||||
self._end_time: Optional[datetime] = None
|
||||
self._error: Optional[str] = None
|
||||
|
||||
def _serialize_types(self) -> Dict[str, Any]:
|
||||
"""Serialize input/output types to config data.
|
||||
|
||||
Returns:
|
||||
Dictionary containing type names and schemas for serialization
|
||||
"""
|
||||
return {
|
||||
"step_id": self.step_id,
|
||||
"metadata": self.metadata,
|
||||
"input_type_name": self.input_type.__name__,
|
||||
"output_type_name": self.output_type.__name__,
|
||||
"input_schema": self.input_type.model_json_schema(),
|
||||
"output_schema": self.output_type.model_json_schema()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _deserialize_types(cls, config: BaseStepConfig) -> Tuple[Type[InputType], Type[OutputType]]:
|
||||
"""Deserialize input/output types from config data using Pydantic's create_model.
|
||||
|
||||
Args:
|
||||
config: Step configuration with embedded schemas
|
||||
|
||||
Returns:
|
||||
Tuple of (input_type, output_type) recreated from schemas
|
||||
"""
|
||||
from typing import List, Dict, Any as AnyType
|
||||
|
||||
def schema_to_field_definitions(schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Convert JSON schema to create_model field definitions."""
|
||||
properties = schema.get('properties', {})
|
||||
required_fields = set(schema.get('required', []))
|
||||
field_definitions = {}
|
||||
|
||||
# Type mapping for JSON schema to Python types
|
||||
type_map = {
|
||||
'string': str,
|
||||
'integer': int,
|
||||
'number': float,
|
||||
'boolean': bool,
|
||||
'array': List[AnyType], # Simplified
|
||||
'object': Dict[str, AnyType]
|
||||
}
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
json_type = field_schema.get('type', 'string')
|
||||
python_type = type_map.get(json_type, str)
|
||||
|
||||
if field_name in required_fields:
|
||||
# For required fields, use (type, ...) format
|
||||
field_definitions[field_name] = (python_type, ...)
|
||||
else:
|
||||
default_value = field_schema.get('default', None)
|
||||
field_definitions[field_name] = (python_type, default_value)
|
||||
|
||||
return field_definitions
|
||||
|
||||
# Extract field definitions from schemas
|
||||
input_fields = schema_to_field_definitions(config.input_schema)
|
||||
output_fields = schema_to_field_definitions(config.output_schema)
|
||||
|
||||
# Use create_model directly with the field definitions
|
||||
input_type = create_model(config.input_type_name, **input_fields)
|
||||
output_type = create_model(config.output_type_name, **output_fields)
|
||||
|
||||
return input_type, output_type
|
||||
|
||||
@property
|
||||
def status(self) -> StepStatus:
|
||||
"""Get current step status."""
|
||||
return self._status
|
||||
|
||||
@property
|
||||
def start_time(self) -> Optional[datetime]:
|
||||
"""Get step start time."""
|
||||
return self._start_time
|
||||
|
||||
@property
|
||||
def end_time(self) -> Optional[datetime]:
|
||||
"""Get step end time."""
|
||||
return self._end_time
|
||||
|
||||
@property
|
||||
def error(self) -> Optional[str]:
|
||||
"""Get step error if any."""
|
||||
return self._error
|
||||
|
||||
@property
|
||||
def duration(self) -> Optional[float]:
|
||||
"""Get step duration in seconds."""
|
||||
if self._start_time and self._end_time:
|
||||
return (self._end_time - self._start_time).total_seconds()
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, input_data: InputType, context: Context) -> OutputType:
|
||||
"""Execute the step logic.
|
||||
|
||||
Args:
|
||||
input_data: Validated input data
|
||||
context: Additional context including workflow state
|
||||
|
||||
Returns:
|
||||
Validated output data
|
||||
|
||||
Raises:
|
||||
Exception: If step execution fails
|
||||
"""
|
||||
pass
|
||||
|
||||
async def run(self, input_data: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Run the step with input validation and error handling.
|
||||
|
||||
Args:
|
||||
input_data: Raw input data to validate
|
||||
context: Additional context including workflow state
|
||||
|
||||
Returns:
|
||||
Dictionary containing output data
|
||||
|
||||
Raises:
|
||||
Exception: If step execution fails after retries
|
||||
"""
|
||||
logger.info(f"Starting step {self.step_id} ({self.metadata.name})")
|
||||
|
||||
self._status = StepStatus.RUNNING
|
||||
self._start_time = datetime.now()
|
||||
self._error = None
|
||||
|
||||
retry_count = 0
|
||||
max_retries = self.metadata.max_retries
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# Validate input
|
||||
validated_input = self.input_type(**input_data)
|
||||
|
||||
# Create typed context from dict
|
||||
if isinstance(context, dict):
|
||||
workflow_state = context.get('workflow_state', {})
|
||||
# Use from_state_ref to avoid copying the state dict
|
||||
typed_context = Context.from_state_ref(workflow_state)
|
||||
else:
|
||||
typed_context = context
|
||||
|
||||
# Execute with timeout if specified
|
||||
if self.metadata.timeout_seconds:
|
||||
output = await asyncio.wait_for(
|
||||
self.execute(validated_input, typed_context),
|
||||
timeout=self.metadata.timeout_seconds
|
||||
)
|
||||
else:
|
||||
output = await self.execute(validated_input, typed_context)
|
||||
|
||||
# Validate output
|
||||
if not isinstance(output, self.output_type):
|
||||
if hasattr(output, 'model_dump'):
|
||||
output = self.output_type(**output.model_dump())
|
||||
elif isinstance(output, dict):
|
||||
output = self.output_type(**output)
|
||||
else:
|
||||
# Try to convert to dict if possible
|
||||
output = self.output_type(result=output)
|
||||
|
||||
self._status = StepStatus.COMPLETED
|
||||
self._end_time = datetime.now()
|
||||
|
||||
logger.info(f"Step {self.step_id} completed successfully in {self.duration:.2f}s")
|
||||
return output.model_dump()
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
error_msg = f"Step {self.step_id} timed out after {self.metadata.timeout_seconds}s"
|
||||
logger.error(error_msg)
|
||||
self._error = error_msg
|
||||
self._status = StepStatus.FAILED
|
||||
self._end_time = datetime.now()
|
||||
raise Exception(error_msg)
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
error_msg = f"Step {self.step_id} failed (attempt {retry_count}/{max_retries + 1}): {str(e)}"
|
||||
logger.error(error_msg)
|
||||
|
||||
if retry_count <= max_retries:
|
||||
logger.info(f"Retrying step {self.step_id} in 1 second...")
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
else:
|
||||
self._error = str(e)
|
||||
self._status = StepStatus.FAILED
|
||||
self._end_time = datetime.now()
|
||||
raise
|
||||
|
||||
# Should never reach here
|
||||
raise Exception(f"Unexpected error in step {self.step_id}")
|
||||
|
||||
def validate_input(self, data: Dict[str, Any]) -> bool:
|
||||
"""Validate input data against the input schema.
|
||||
|
||||
Args:
|
||||
data: Input data to validate
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
self.input_type(**data)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def validate_output(self, data: Dict[str, Any]) -> bool:
|
||||
"""Validate output data against the output schema.
|
||||
|
||||
Args:
|
||||
data: Output data to validate
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
self.output_type(**data)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_schema(self) -> Dict[str, Any]:
|
||||
"""Get the input/output schema for this step.
|
||||
|
||||
Returns:
|
||||
Dictionary containing input and output schemas
|
||||
"""
|
||||
return {
|
||||
"step_id": self.step_id,
|
||||
"metadata": self.metadata.model_dump(),
|
||||
"input_type": self.input_type.__name__,
|
||||
"output_type": self.output_type.__name__,
|
||||
"input_schema": self.input_type.model_json_schema(),
|
||||
"output_schema": self.output_type.model_json_schema()
|
||||
}
|
||||
|
||||
|
||||
class FunctionStepConfig(BaseStepConfig):
|
||||
"""Configuration for FunctionStep serialization."""
|
||||
# Base fields inherited: step_id, metadata, input_type_name, output_type_name, input_schema, output_schema
|
||||
# Note: We can't easily serialize functions, so we'll store a reference
|
||||
function_name: Optional[str] = None
|
||||
function_module: Optional[str] = None
|
||||
|
||||
|
||||
class FunctionStep(Component[FunctionStepConfig], BaseStep[InputType, OutputType]):
|
||||
"""A step that executes a function as its core operation."""
|
||||
|
||||
component_config_schema = FunctionStepConfig
|
||||
component_type = "step"
|
||||
component_provider_override = "autogenstudio.workflow.steps.FunctionStep"
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_id: str,
|
||||
metadata: StepMetadata,
|
||||
input_type: Type[InputType],
|
||||
output_type: Type[OutputType],
|
||||
func: Callable
|
||||
):
|
||||
"""Initialize with a function to execute.
|
||||
|
||||
Args:
|
||||
step_id: Unique identifier for this step
|
||||
metadata: Step metadata
|
||||
input_type: Input validation model
|
||||
output_type: Output validation model
|
||||
func: Function to execute (can be sync or async)
|
||||
"""
|
||||
super().__init__(step_id, metadata, input_type, output_type)
|
||||
self.func = func
|
||||
|
||||
async def execute(self, input_data: InputType, context: Context) -> OutputType:
|
||||
"""Execute the wrapped function.
|
||||
|
||||
Args:
|
||||
input_data: Validated input data
|
||||
context: Additional context
|
||||
|
||||
Returns:
|
||||
Function output
|
||||
"""
|
||||
if asyncio.iscoroutinefunction(self.func):
|
||||
result = await self.func(input_data, context)
|
||||
else:
|
||||
result = self.func(input_data, context)
|
||||
|
||||
if isinstance(result, dict):
|
||||
return self.output_type(**result)
|
||||
elif hasattr(result, 'dict'):
|
||||
return result
|
||||
else:
|
||||
# Assume it's a simple value that can be wrapped
|
||||
return self.output_type(result=result)
|
||||
|
||||
def _to_config(self) -> FunctionStepConfig:
|
||||
"""Convert step to configuration for serialization."""
|
||||
func_name = None
|
||||
func_module = None
|
||||
|
||||
if hasattr(self.func, '__name__'):
|
||||
func_name = self.func.__name__
|
||||
if hasattr(self.func, '__module__'):
|
||||
func_module = self.func.__module__
|
||||
|
||||
# Get base type serialization data
|
||||
base_data = self._serialize_types()
|
||||
|
||||
return FunctionStepConfig(
|
||||
**base_data,
|
||||
function_name=func_name,
|
||||
function_module=func_module
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: FunctionStepConfig) -> "FunctionStep":
|
||||
"""Create step from configuration.
|
||||
|
||||
Args:
|
||||
config: Step configuration
|
||||
|
||||
Note:
|
||||
This basic implementation cannot recreate the function.
|
||||
In practice, you'd need a function registry or other mechanism
|
||||
to deserialize callable functions.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"FunctionStep deserialization is not fully supported as functions "
|
||||
"cannot be easily serialized. Consider using a function registry "
|
||||
"or other mechanism for this use case."
|
||||
)
|
||||
|
||||
|
||||
class EchoStepConfig(BaseStepConfig):
|
||||
"""Configuration for EchoStep serialization."""
|
||||
# Base fields inherited: step_id, metadata, input_type_name, output_type_name, input_schema, output_schema
|
||||
prefix: str = "Echo: "
|
||||
suffix: str = ""
|
||||
|
||||
|
||||
class EchoStep(Component[EchoStepConfig], BaseStep[InputType, OutputType]):
|
||||
"""A simple step that echoes input with prefix/suffix - fully serializable."""
|
||||
|
||||
component_config_schema = EchoStepConfig
|
||||
component_type = "step"
|
||||
component_provider_override = "autogenstudio.workflow.steps.EchoStep"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_id: str,
|
||||
metadata: StepMetadata,
|
||||
input_type: Type[InputType],
|
||||
output_type: Type[OutputType],
|
||||
prefix: str = "Echo: ",
|
||||
suffix: str = ""
|
||||
):
|
||||
"""Initialize the echo step.
|
||||
|
||||
Args:
|
||||
step_id: Unique identifier for this step
|
||||
metadata: Step metadata
|
||||
input_type: Pydantic model class for input validation
|
||||
output_type: Pydantic model class for output validation
|
||||
prefix: String to prepend to input
|
||||
suffix: String to append to input
|
||||
"""
|
||||
super().__init__(step_id, metadata, input_type, output_type)
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
|
||||
async def execute(self, input_data: InputType, context: Context) -> OutputType:
|
||||
"""Execute the echo operation.
|
||||
|
||||
Args:
|
||||
input_data: Input data (must have a message, result, text, or other field)
|
||||
context: Workflow context
|
||||
|
||||
Returns:
|
||||
Output with echoed message
|
||||
"""
|
||||
# Try to get the message from different possible field names
|
||||
message = None
|
||||
|
||||
# Try common field names
|
||||
for field_name in ['message', 'result', 'text', 'content', 'data']:
|
||||
if hasattr(input_data, field_name):
|
||||
message = getattr(input_data, field_name)
|
||||
break
|
||||
|
||||
# If no common field found, try the first field
|
||||
if message is None:
|
||||
field_names = list(input_data.model_fields.keys())
|
||||
if field_names:
|
||||
message = getattr(input_data, field_names[0])
|
||||
else:
|
||||
# Fall back to string representation
|
||||
message = str(input_data)
|
||||
|
||||
result = f"{self.prefix}{message}{self.suffix}"
|
||||
|
||||
# Store echo operation in context
|
||||
context.set(f'{self.step_id}_echo_info', {
|
||||
'original': message,
|
||||
'prefix': self.prefix,
|
||||
'suffix': self.suffix,
|
||||
'result': result
|
||||
})
|
||||
|
||||
# Create output - try different field names
|
||||
output_fields = list(self.output_type.model_fields.keys())
|
||||
if 'result' in output_fields:
|
||||
return self.output_type(result=result)
|
||||
elif 'message' in output_fields:
|
||||
return self.output_type(message=result)
|
||||
elif 'text' in output_fields:
|
||||
return self.output_type(text=result)
|
||||
elif 'response' in output_fields:
|
||||
return self.output_type(response=result)
|
||||
elif 'content' in output_fields:
|
||||
return self.output_type(content=result)
|
||||
else:
|
||||
# Fall back to first field
|
||||
field_name = output_fields[0]
|
||||
return self.output_type(**{field_name: result})
|
||||
|
||||
def _to_config(self) -> EchoStepConfig:
|
||||
"""Convert step to configuration for serialization."""
|
||||
# Get base type serialization data
|
||||
base_data = self._serialize_types()
|
||||
|
||||
return EchoStepConfig(
|
||||
**base_data,
|
||||
prefix=self.prefix,
|
||||
suffix=self.suffix
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: EchoStepConfig) -> "EchoStep":
|
||||
"""Create step from configuration using shared schema-based deserialization.
|
||||
|
||||
Args:
|
||||
config: Step configuration with embedded schemas
|
||||
|
||||
Returns:
|
||||
Recreated EchoStep instance with dynamically created types
|
||||
"""
|
||||
# Use shared type deserialization
|
||||
input_type, output_type = cls._deserialize_types(config)
|
||||
|
||||
return cls(
|
||||
step_id=config.step_id,
|
||||
metadata=config.metadata,
|
||||
input_type=input_type,
|
||||
output_type=output_type,
|
||||
prefix=config.prefix,
|
||||
suffix=config.suffix
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# Tests for workflow process system
|
||||
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Pytest configuration for workflow process tests.
|
||||
"""
|
||||
import pytest
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
# Configure logging for tests
|
||||
logging.basicConfig(level=logging.INFO, format='%(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for the test session."""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
@@ -0,0 +1,276 @@
|
||||
"""
|
||||
Test workflow type validation and serialization.
|
||||
"""
|
||||
import pytest
|
||||
from typing import Any, Dict
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogenstudio.workflow.steps import EchoStep
|
||||
from autogenstudio.workflow.core import Workflow, WorkflowRunner, StepMetadata, WorkflowMetadata
|
||||
|
||||
|
||||
# Test data models
|
||||
class TextInput(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class TextOutput(BaseModel):
|
||||
result: str
|
||||
|
||||
|
||||
class NumberInput(BaseModel):
|
||||
value: int
|
||||
|
||||
|
||||
class NumberOutput(BaseModel):
|
||||
result: int
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow():
|
||||
"""Create a sample workflow for testing."""
|
||||
workflow = Workflow(
|
||||
metadata=WorkflowMetadata(
|
||||
name="Test Workflow",
|
||||
version="1.0.0"
|
||||
)
|
||||
)
|
||||
|
||||
step1 = EchoStep(
|
||||
step_id="step1",
|
||||
metadata=StepMetadata(name="First Step"),
|
||||
input_type=TextInput,
|
||||
output_type=TextOutput,
|
||||
prefix="[1] ",
|
||||
suffix=""
|
||||
)
|
||||
|
||||
step2 = EchoStep(
|
||||
step_id="step2",
|
||||
metadata=StepMetadata(name="Second Step"),
|
||||
input_type=TextOutput, # Compatible with step1 output
|
||||
output_type=TextOutput,
|
||||
prefix="[2] ",
|
||||
suffix=" (done)"
|
||||
)
|
||||
|
||||
workflow.add_step(step1).add_step(step2)
|
||||
workflow.add_edge("step1", "step2")
|
||||
workflow.set_start_step("step1").add_end_step("step2")
|
||||
|
||||
return workflow
|
||||
|
||||
|
||||
class TestWorkflowValidation:
|
||||
"""Test workflow validation, especially type checking."""
|
||||
|
||||
def test_compatible_types_pass_validation(self, sample_workflow):
|
||||
"""Test that compatible types pass validation."""
|
||||
validation = sample_workflow.validate_workflow()
|
||||
assert validation.is_valid
|
||||
assert len(validation.errors) == 0
|
||||
|
||||
def test_incompatible_types_fail_validation(self):
|
||||
"""Test that incompatible types fail validation."""
|
||||
workflow = Workflow(
|
||||
metadata=WorkflowMetadata(name="Invalid Workflow", version="1.0.0")
|
||||
)
|
||||
|
||||
# Create steps with incompatible types
|
||||
step1 = EchoStep(
|
||||
step_id="text_step",
|
||||
metadata=StepMetadata(name="Text Step"),
|
||||
input_type=TextInput,
|
||||
output_type=TextOutput, # Outputs TextOutput
|
||||
prefix="Text: ",
|
||||
suffix=""
|
||||
)
|
||||
|
||||
step2 = EchoStep(
|
||||
step_id="number_step",
|
||||
metadata=StepMetadata(name="Number Step"),
|
||||
input_type=NumberInput, # Expects NumberInput (incompatible!)
|
||||
output_type=NumberOutput,
|
||||
prefix="Number: ",
|
||||
suffix=""
|
||||
)
|
||||
|
||||
workflow.add_step(step1).add_step(step2)
|
||||
workflow.add_edge("text_step", "number_step") # This should fail validation
|
||||
workflow.set_start_step("text_step").add_end_step("number_step")
|
||||
|
||||
validation = workflow.validate_workflow()
|
||||
assert not validation.is_valid
|
||||
assert len(validation.errors) > 0
|
||||
assert any("Type mismatch" in error for error in validation.errors)
|
||||
|
||||
def test_schema_based_validation_works(self):
|
||||
"""Test that schema-based validation correctly identifies compatible types."""
|
||||
workflow = Workflow(
|
||||
metadata=WorkflowMetadata(name="Schema Test", version="1.0.0")
|
||||
)
|
||||
|
||||
# Create two steps that use the same schema but different type instances
|
||||
step1 = EchoStep(
|
||||
step_id="step1",
|
||||
metadata=StepMetadata(name="Step 1"),
|
||||
input_type=TextInput,
|
||||
output_type=TextOutput,
|
||||
prefix="[1] ",
|
||||
suffix=""
|
||||
)
|
||||
|
||||
# This step also uses TextOutput as input - should be compatible by schema
|
||||
step2 = EchoStep(
|
||||
step_id="step2",
|
||||
metadata=StepMetadata(name="Step 2"),
|
||||
input_type=TextOutput,
|
||||
output_type=TextOutput,
|
||||
prefix="[2] ",
|
||||
suffix=""
|
||||
)
|
||||
|
||||
workflow.add_step(step1).add_step(step2)
|
||||
workflow.add_edge("step1", "step2")
|
||||
workflow.set_start_step("step1").add_end_step("step2")
|
||||
|
||||
validation = workflow.validate_workflow()
|
||||
assert validation.is_valid, f"Validation failed: {validation.errors}"
|
||||
|
||||
|
||||
class TestWorkflowSerialization:
|
||||
"""Test workflow serialization and deserialization."""
|
||||
|
||||
def test_workflow_serialization_roundtrip(self, sample_workflow):
|
||||
"""Test that workflow can be serialized and deserialized correctly."""
|
||||
# Serialize
|
||||
config = sample_workflow.dump_component()
|
||||
assert config.provider == "autogenstudio.workflow.core.Workflow"
|
||||
assert len(config.config['steps']) == 2
|
||||
|
||||
# Deserialize
|
||||
new_workflow = Workflow.load_component(config)
|
||||
assert new_workflow.metadata.name == sample_workflow.metadata.name
|
||||
assert len(new_workflow.steps) == 2
|
||||
assert new_workflow.start_step_id == sample_workflow.start_step_id
|
||||
assert new_workflow.end_step_ids == sample_workflow.end_step_ids
|
||||
|
||||
# Validate deserialized workflow
|
||||
validation = new_workflow.validate_workflow()
|
||||
assert validation.is_valid, f"Deserialized workflow validation failed: {validation.errors}"
|
||||
|
||||
def test_serialized_workflow_execution_matches_original(self, sample_workflow):
|
||||
"""Test that serialized workflow produces same results as original."""
|
||||
# Run original workflow
|
||||
runner = WorkflowRunner()
|
||||
input_data = {"message": "test"}
|
||||
|
||||
import asyncio
|
||||
|
||||
async def run_test():
|
||||
result1 = await runner.run(sample_workflow, input_data)
|
||||
|
||||
# Serialize and deserialize
|
||||
config = sample_workflow.dump_component()
|
||||
new_workflow = Workflow.load_component(config)
|
||||
|
||||
# Run deserialized workflow
|
||||
runner2 = WorkflowRunner()
|
||||
result2 = await runner2.run(new_workflow, input_data)
|
||||
|
||||
# Compare final outputs
|
||||
def get_final_output(result):
|
||||
for step_id, step_exec in result.step_executions.items():
|
||||
if step_id in new_workflow.end_step_ids:
|
||||
return step_exec.output_data
|
||||
return None
|
||||
|
||||
output1 = get_final_output(result1)
|
||||
output2 = get_final_output(result2)
|
||||
|
||||
assert output1 == output2, f"Outputs don't match: {output1} vs {output2}"
|
||||
|
||||
asyncio.run(run_test())
|
||||
|
||||
def test_serialization_preserves_step_configuration(self, sample_workflow):
|
||||
"""Test that step configuration is preserved through serialization."""
|
||||
# Serialize and deserialize
|
||||
config = sample_workflow.dump_component()
|
||||
new_workflow = Workflow.load_component(config)
|
||||
|
||||
# Check that step configurations are preserved
|
||||
original_step1 = sample_workflow.steps["step1"]
|
||||
new_step1 = new_workflow.steps["step1"]
|
||||
|
||||
assert original_step1.step_id == new_step1.step_id
|
||||
assert original_step1.metadata.name == new_step1.metadata.name
|
||||
assert original_step1.prefix == new_step1.prefix
|
||||
assert original_step1.suffix == new_step1.suffix
|
||||
|
||||
|
||||
class TestWorkflowBasics:
|
||||
"""Test basic workflow functionality."""
|
||||
|
||||
def test_workflow_creation(self):
|
||||
"""Test basic workflow creation."""
|
||||
workflow = Workflow(
|
||||
metadata=WorkflowMetadata(name="Test", version="1.0.0")
|
||||
)
|
||||
assert workflow.metadata.name == "Test"
|
||||
assert len(workflow.steps) == 0
|
||||
assert len(workflow.edges) == 0
|
||||
|
||||
def test_step_addition(self):
|
||||
"""Test adding steps to workflow."""
|
||||
workflow = Workflow(
|
||||
metadata=WorkflowMetadata(name="Test", version="1.0.0")
|
||||
)
|
||||
|
||||
step = EchoStep(
|
||||
step_id="test_step",
|
||||
metadata=StepMetadata(name="Test Step"),
|
||||
input_type=TextInput,
|
||||
output_type=TextOutput,
|
||||
prefix="Test: ",
|
||||
suffix=""
|
||||
)
|
||||
|
||||
workflow.add_step(step)
|
||||
assert "test_step" in workflow.steps
|
||||
assert workflow.steps["test_step"] is step
|
||||
|
||||
def test_edge_addition(self):
|
||||
"""Test adding edges between steps."""
|
||||
workflow = Workflow(
|
||||
metadata=WorkflowMetadata(name="Test", version="1.0.0")
|
||||
)
|
||||
|
||||
step1 = EchoStep(
|
||||
step_id="step1",
|
||||
metadata=StepMetadata(name="Step 1"),
|
||||
input_type=TextInput,
|
||||
output_type=TextOutput,
|
||||
prefix="1: ",
|
||||
suffix=""
|
||||
)
|
||||
|
||||
step2 = EchoStep(
|
||||
step_id="step2",
|
||||
metadata=StepMetadata(name="Step 2"),
|
||||
input_type=TextOutput,
|
||||
output_type=TextOutput,
|
||||
prefix="2: ",
|
||||
suffix=""
|
||||
)
|
||||
|
||||
workflow.add_step(step1).add_step(step2)
|
||||
workflow.add_edge("step1", "step2")
|
||||
|
||||
assert len(workflow.edges) == 1
|
||||
assert workflow.edges[0].from_step == "step1"
|
||||
assert workflow.edges[0].to_step == "step2"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Allow running as script for quick testing
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -0,0 +1,248 @@
|
||||
import {
|
||||
Workflow,
|
||||
Step,
|
||||
StepLibrary,
|
||||
ApiResponse,
|
||||
CreateWorkflowRequest,
|
||||
UpdateWorkflowRequest,
|
||||
WorkflowConfig,
|
||||
WorkflowRun,
|
||||
} from "./types";
|
||||
|
||||
// Mock data for development
|
||||
const mockSteps: Step[] = [
|
||||
{
|
||||
id: "step-1",
|
||||
name: "Research Assistant",
|
||||
description: "A step that can research topics and gather information",
|
||||
type: "agent_step",
|
||||
system_message:
|
||||
"You are a helpful research assistant. Gather comprehensive information on given topics.",
|
||||
tools: ["web_search", "document_analysis"],
|
||||
model: "gpt-4o-mini",
|
||||
},
|
||||
{
|
||||
id: "step-2",
|
||||
name: "Data Analyst",
|
||||
description: "A step specialized in analyzing data and creating insights",
|
||||
type: "agent_step",
|
||||
system_message:
|
||||
"You are a data analyst. Analyze data and provide clear insights and visualizations.",
|
||||
tools: ["python_executor", "chart_generator"],
|
||||
model: "gpt-4o-mini",
|
||||
},
|
||||
{
|
||||
id: "step-3",
|
||||
name: "Content Writer",
|
||||
description: "A step that creates well-written content based on research",
|
||||
type: "agent_step",
|
||||
system_message:
|
||||
"You are a professional content writer. Create engaging, well-structured content.",
|
||||
tools: ["text_editor"],
|
||||
model: "gpt-4o-mini",
|
||||
},
|
||||
{
|
||||
id: "step-4",
|
||||
name: "Code Reviewer",
|
||||
description: "A step that reviews code for quality and best practices",
|
||||
type: "agent_step",
|
||||
system_message:
|
||||
"You are a senior software engineer. Review code for quality, security, and best practices.",
|
||||
tools: ["code_analysis", "security_scanner"],
|
||||
model: "gpt-4o-mini",
|
||||
},
|
||||
{
|
||||
id: "step-5",
|
||||
name: "Summary Step",
|
||||
description:
|
||||
"A step that takes outputs from other steps and generates comprehensive summaries",
|
||||
type: "agent_step",
|
||||
system_message:
|
||||
"You are a summary specialist. Take outputs from multiple steps and create clear, concise, and comprehensive summaries that highlight key findings, insights, and recommendations.",
|
||||
tools: ["text_processor", "content_aggregator"],
|
||||
model: "gpt-4o-mini",
|
||||
},
|
||||
];
|
||||
|
||||
const mockStepLibrary: StepLibrary = {
|
||||
name: "Default Library",
|
||||
description: "A collection of default steps for AutoGen Studio",
|
||||
steps: mockSteps,
|
||||
};
|
||||
|
||||
const mockWorkflows: Workflow[] = [
|
||||
{
|
||||
id: "workflow-1",
|
||||
name: "Research and Analysis Pipeline",
|
||||
description:
|
||||
"A workflow that researches a topic, analyzes data, creates content, and generates a summary",
|
||||
created_at: "2025-01-10T10:00:00Z",
|
||||
updated_at: "2025-01-10T10:00:00Z",
|
||||
config: {
|
||||
id: "workflow-1",
|
||||
name: "Research and Analysis Pipeline",
|
||||
description:
|
||||
"Sequential workflow for research, analysis, content creation, and summarization",
|
||||
steps: [mockSteps[0], mockSteps[1], mockSteps[2], mockSteps[4]],
|
||||
edges: [
|
||||
{
|
||||
id: "edge-1",
|
||||
from_step: "step-1",
|
||||
to_step: "step-2",
|
||||
},
|
||||
{
|
||||
id: "edge-2",
|
||||
from_step: "step-2",
|
||||
to_step: "step-3",
|
||||
},
|
||||
{
|
||||
id: "edge-3",
|
||||
from_step: "step-3",
|
||||
to_step: "step-5",
|
||||
},
|
||||
],
|
||||
start_step_id: "step-1",
|
||||
} as WorkflowConfig,
|
||||
},
|
||||
{
|
||||
id: "workflow-2",
|
||||
name: "Code Generation and Review",
|
||||
description:
|
||||
"A workflow that generates code, reviews it, and prepares it for deployment",
|
||||
created_at: "2025-01-12T14:30:00Z",
|
||||
updated_at: "2025-01-12T14:30:00Z",
|
||||
config: {
|
||||
id: "workflow-2",
|
||||
name: "Code Generation and Review",
|
||||
description: "Generates and reviews code for quality assurance",
|
||||
steps: [mockSteps[2], mockSteps[3]],
|
||||
edges: [
|
||||
{
|
||||
id: "edge-4",
|
||||
from_step: "step-2",
|
||||
to_step: "step-3",
|
||||
},
|
||||
],
|
||||
start_step_id: "step-2",
|
||||
} as WorkflowConfig,
|
||||
},
|
||||
];
|
||||
|
||||
// Simulate API latency
|
||||
const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
|
||||
|
||||
export const workflowAPI = {
|
||||
getWorkflows: async (): Promise<ApiResponse<Workflow[]>> => {
|
||||
await sleep(500);
|
||||
console.log("API: getWorkflows", mockWorkflows);
|
||||
return { success: true, data: mockWorkflows };
|
||||
},
|
||||
|
||||
getWorkflow: async (id: string): Promise<ApiResponse<Workflow>> => {
|
||||
await sleep(500);
|
||||
const workflow = mockWorkflows.find((w) => w.id === id);
|
||||
if (workflow) {
|
||||
console.log("API: getWorkflow", workflow);
|
||||
return { success: true, data: workflow };
|
||||
}
|
||||
return {
|
||||
success: false,
|
||||
data: {} as Workflow,
|
||||
message: "Workflow not found",
|
||||
};
|
||||
},
|
||||
|
||||
getSteps: async (): Promise<ApiResponse<Step[]>> => {
|
||||
await sleep(300);
|
||||
console.log("API: getSteps", mockSteps);
|
||||
return { success: true, data: mockSteps };
|
||||
},
|
||||
|
||||
getStepLibrary: async (): Promise<ApiResponse<StepLibrary>> => {
|
||||
await sleep(300);
|
||||
console.log("API: getStepLibrary", mockStepLibrary);
|
||||
return { success: true, data: mockStepLibrary };
|
||||
},
|
||||
|
||||
createWorkflow: async (
|
||||
data: CreateWorkflowRequest
|
||||
): Promise<ApiResponse<Workflow>> => {
|
||||
await sleep(500);
|
||||
const newWorkflow: Workflow = {
|
||||
id: `workflow-${Date.now()}`,
|
||||
created_at: new Date().toISOString(),
|
||||
updated_at: new Date().toISOString(),
|
||||
...data,
|
||||
config: {
|
||||
...data.config,
|
||||
id: `config-${Date.now()}`,
|
||||
},
|
||||
};
|
||||
mockWorkflows.push(newWorkflow);
|
||||
console.log("API: createWorkflow", newWorkflow);
|
||||
return { success: true, data: newWorkflow };
|
||||
},
|
||||
|
||||
updateWorkflow: async (
|
||||
id: string,
|
||||
data: UpdateWorkflowRequest
|
||||
): Promise<ApiResponse<Workflow>> => {
|
||||
await sleep(500);
|
||||
const index = mockWorkflows.findIndex((w) => w.id === id);
|
||||
if (index !== -1) {
|
||||
const updatedWorkflow = {
|
||||
...mockWorkflows[index],
|
||||
...data,
|
||||
config: {
|
||||
...mockWorkflows[index].config,
|
||||
...data.config,
|
||||
},
|
||||
updated_at: new Date().toISOString(),
|
||||
};
|
||||
mockWorkflows[index] = updatedWorkflow;
|
||||
console.log("API: updateWorkflow", updatedWorkflow);
|
||||
return { success: true, data: updatedWorkflow };
|
||||
}
|
||||
return {
|
||||
success: false,
|
||||
data: {} as Workflow,
|
||||
message: "Workflow not found",
|
||||
};
|
||||
},
|
||||
|
||||
deleteWorkflow: async (id: string): Promise<ApiResponse<boolean>> => {
|
||||
await sleep(500);
|
||||
const index = mockWorkflows.findIndex((w) => w.id === id);
|
||||
if (index !== -1) {
|
||||
mockWorkflows.splice(index, 1);
|
||||
console.log("API: deleteWorkflow", id);
|
||||
return { success: true, data: true };
|
||||
}
|
||||
return { success: false, data: false, message: "Workflow not found" };
|
||||
},
|
||||
|
||||
runWorkflow: async (
|
||||
workflowId: string,
|
||||
input: Record<string, any>
|
||||
): Promise<ApiResponse<WorkflowRun>> => {
|
||||
await sleep(1500);
|
||||
console.log("API: runWorkflow", workflowId, input);
|
||||
const run: WorkflowRun = {
|
||||
id: `run-${Date.now()}`,
|
||||
workflow_id: workflowId,
|
||||
status: "completed",
|
||||
created_at: new Date().toISOString(),
|
||||
updated_at: new Date().toISOString(),
|
||||
inputs: input,
|
||||
outputs: [
|
||||
{
|
||||
step_id: "step-5",
|
||||
output: {
|
||||
summary: "This is a mock summary of the research and analysis.",
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
return { success: true, data: run };
|
||||
},
|
||||
};
|
||||
@@ -0,0 +1,443 @@
|
||||
import React, { useCallback, useEffect, useState } from "react";
|
||||
import {
|
||||
ReactFlow,
|
||||
useNodesState,
|
||||
useEdgesState,
|
||||
addEdge,
|
||||
Connection,
|
||||
Background,
|
||||
MiniMap,
|
||||
Panel,
|
||||
Node,
|
||||
Edge,
|
||||
BackgroundVariant,
|
||||
NodeProps,
|
||||
NodeTypes,
|
||||
} from "@xyflow/react";
|
||||
import "@xyflow/react/dist/style.css";
|
||||
import { message, Drawer, Button } from "antd";
|
||||
import { Workflow, Step, NodeData } from "./types";
|
||||
import { StepLibrary } from "./library";
|
||||
import { StepNode } from "./nodes";
|
||||
import { Toolbar } from "./toolbar";
|
||||
import {
|
||||
convertToReactFlowNodes,
|
||||
convertToReactFlowEdges,
|
||||
addStepToWorkflow,
|
||||
saveNodePosition,
|
||||
removeNodePosition,
|
||||
calculateNodePosition,
|
||||
} from "./utils";
|
||||
|
||||
// Custom node types
|
||||
const nodeTypes: NodeTypes = {
|
||||
step: StepNode,
|
||||
};
|
||||
|
||||
interface WorkflowBuilderProps {
|
||||
workflow: Workflow;
|
||||
onChange?: (workflow: Partial<Workflow>) => void;
|
||||
onSave?: (workflow: Partial<Workflow>) => void;
|
||||
onDirtyStateChange?: (isDirty: boolean) => void;
|
||||
}
|
||||
|
||||
export const WorkflowBuilder: React.FC<WorkflowBuilderProps> = ({
|
||||
workflow,
|
||||
onChange,
|
||||
onSave,
|
||||
onDirtyStateChange,
|
||||
}) => {
|
||||
const [nodes, setNodes, onNodesChange] = useNodesState<Node<NodeData>>([]);
|
||||
const [edges, setEdges, onEdgesChange] = useEdgesState<Edge>([]);
|
||||
const [isLibraryCompact, setIsLibraryCompact] = useState(false);
|
||||
const [showMiniMap, setShowMiniMap] = useState(false);
|
||||
const [showGrid, setShowGrid] = useState(true);
|
||||
const [isDirty, setIsDirty] = useState(false);
|
||||
const [selectedStep, setSelectedStep] = useState<Step | null>(null);
|
||||
const [stepDrawerOpen, setStepDrawerOpen] = useState(false);
|
||||
const [edgeType, setEdgeType] = useState<string>("smoothstep");
|
||||
|
||||
const [messageApi, contextHolder] = message.useMessage();
|
||||
|
||||
// Notify parent of dirty state changes
|
||||
useEffect(() => {
|
||||
onDirtyStateChange?.(isDirty);
|
||||
}, [isDirty, onDirtyStateChange]);
|
||||
|
||||
const onConnect = useCallback(
|
||||
(params: Connection) => {
|
||||
setEdges((eds) =>
|
||||
addEdge(
|
||||
{
|
||||
...params,
|
||||
type: edgeType,
|
||||
},
|
||||
eds
|
||||
)
|
||||
);
|
||||
|
||||
const newEdge = {
|
||||
id: `edge-${params.source}-${params.target}-${Date.now()}`,
|
||||
from_step: params.source!,
|
||||
to_step: params.target!,
|
||||
};
|
||||
|
||||
const updatedConfig = {
|
||||
...workflow.config,
|
||||
edges: [...workflow.config.edges, newEdge],
|
||||
};
|
||||
|
||||
onChange?.({
|
||||
...workflow,
|
||||
config: updatedConfig,
|
||||
});
|
||||
|
||||
setIsDirty(true);
|
||||
messageApi.success("Edge added successfully");
|
||||
},
|
||||
[workflow.config, setEdges, onChange, messageApi, edgeType]
|
||||
);
|
||||
|
||||
const onNodeDragStop = useCallback(
|
||||
(event: any, node: Node) => {
|
||||
if (workflow.id) {
|
||||
saveNodePosition(workflow.id, node.id, node.position);
|
||||
setIsDirty(true);
|
||||
}
|
||||
},
|
||||
[workflow.id]
|
||||
);
|
||||
|
||||
const handleAddStep = useCallback(
|
||||
(step: Step) => {
|
||||
const position = calculateNodePosition(
|
||||
workflow.config.steps.length,
|
||||
workflow.config.steps.length + 1
|
||||
);
|
||||
|
||||
if (workflow.id) {
|
||||
saveNodePosition(workflow.id, step.id, position);
|
||||
}
|
||||
|
||||
const updatedConfig = addStepToWorkflow(workflow.config, step);
|
||||
|
||||
const newWorkflow = {
|
||||
...workflow,
|
||||
config: updatedConfig,
|
||||
};
|
||||
|
||||
onChange?.(newWorkflow);
|
||||
setIsDirty(true);
|
||||
messageApi.success(`Added ${step.name} to workflow`);
|
||||
},
|
||||
[workflow, onChange, messageApi]
|
||||
);
|
||||
|
||||
const handleDeleteStep = useCallback(
|
||||
(stepId: string) => {
|
||||
const updatedConfig = {
|
||||
...workflow.config,
|
||||
steps: workflow.config.steps.filter((s) => s.id !== stepId),
|
||||
edges: workflow.config.edges.filter(
|
||||
(e) => e.from_step !== stepId && e.to_step !== stepId
|
||||
),
|
||||
start_step_id:
|
||||
workflow.config.start_step_id === stepId
|
||||
? undefined
|
||||
: workflow.config.start_step_id,
|
||||
};
|
||||
|
||||
const stepName =
|
||||
workflow.config.steps.find((s) => s.id === stepId)?.name || "Step";
|
||||
|
||||
if (workflow.id) {
|
||||
removeNodePosition(workflow.id, stepId);
|
||||
}
|
||||
|
||||
onChange?.({
|
||||
...workflow,
|
||||
config: updatedConfig,
|
||||
});
|
||||
setIsDirty(true);
|
||||
messageApi.success(`Removed ${stepName} from workflow`);
|
||||
},
|
||||
[workflow, onChange, messageApi]
|
||||
);
|
||||
|
||||
const onEdgesDelete = useCallback(
|
||||
(edgesToDelete: Edge[]) => {
|
||||
const edgeIdsToDelete = new Set(edgesToDelete.map((edge) => edge.id));
|
||||
const updatedConfig = {
|
||||
...workflow.config,
|
||||
edges: workflow.config.edges.filter(
|
||||
(edge) => !edgeIdsToDelete.has(edge.id)
|
||||
),
|
||||
};
|
||||
|
||||
onChange?.({
|
||||
...workflow,
|
||||
config: updatedConfig,
|
||||
});
|
||||
|
||||
setIsDirty(true);
|
||||
messageApi.success(`Deleted ${edgesToDelete.length} edge(s)`);
|
||||
},
|
||||
[workflow, onChange, messageApi]
|
||||
);
|
||||
|
||||
// Initialize nodes and edges from workflow - positions come from localStorage
|
||||
useEffect(() => {
|
||||
const flowNodes = convertToReactFlowNodes(
|
||||
workflow.config,
|
||||
workflow.id || "temp",
|
||||
handleDeleteStep
|
||||
);
|
||||
const flowEdges = convertToReactFlowEdges(workflow.config, edgeType);
|
||||
setNodes(flowNodes);
|
||||
setEdges(flowEdges);
|
||||
}, [
|
||||
workflow.config,
|
||||
workflow.id,
|
||||
edgeType,
|
||||
setNodes,
|
||||
setEdges,
|
||||
handleDeleteStep,
|
||||
]);
|
||||
|
||||
const handleSaveWorkflow = useCallback(async () => {
|
||||
try {
|
||||
await onSave?.(workflow);
|
||||
setIsDirty(false);
|
||||
messageApi.success("Workflow saved successfully");
|
||||
} catch (error) {
|
||||
console.error("Error saving workflow:", error);
|
||||
messageApi.error("Failed to save workflow");
|
||||
}
|
||||
}, [workflow, onSave, messageApi]);
|
||||
|
||||
const handleRunWorkflow = useCallback(() => {
|
||||
messageApi.info("Workflow execution coming soon!");
|
||||
}, [messageApi]);
|
||||
|
||||
const handleLayoutNodes = useCallback(() => {
|
||||
if (workflow.id) {
|
||||
workflow.config.steps.forEach((step, index) => {
|
||||
const position = calculateNodePosition(
|
||||
index,
|
||||
workflow.config.steps.length
|
||||
);
|
||||
saveNodePosition(workflow.id!, step.id, position);
|
||||
});
|
||||
|
||||
const flowNodes = convertToReactFlowNodes(
|
||||
workflow.config,
|
||||
workflow.id,
|
||||
handleDeleteStep
|
||||
);
|
||||
setNodes(flowNodes);
|
||||
|
||||
messageApi.success("Nodes arranged automatically");
|
||||
}
|
||||
}, [workflow, messageApi, handleDeleteStep, setNodes]);
|
||||
|
||||
const handleStepClick = useCallback((step: Step) => {
|
||||
setSelectedStep(step);
|
||||
setStepDrawerOpen(true);
|
||||
}, []);
|
||||
|
||||
const handleEdgeTypeChange = useCallback(
|
||||
(newEdgeType: string) => {
|
||||
setEdgeType(newEdgeType);
|
||||
setEdges((currentEdges) =>
|
||||
currentEdges.map((edge) => ({
|
||||
...edge,
|
||||
type: newEdgeType,
|
||||
}))
|
||||
);
|
||||
setIsDirty(true);
|
||||
},
|
||||
[setEdges]
|
||||
);
|
||||
|
||||
// Handle keyboard events for node deletion
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (event.key === "Delete" || event.key === "Backspace") {
|
||||
const selectedNodes = nodes.filter((node) => node.selected);
|
||||
selectedNodes.forEach((node) => {
|
||||
handleDeleteStep(node.id);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
document.addEventListener("keydown", handleKeyDown);
|
||||
return () => {
|
||||
document.removeEventListener("keydown", handleKeyDown);
|
||||
};
|
||||
}, [nodes, handleDeleteStep]);
|
||||
|
||||
return (
|
||||
<div className="h-full flex">
|
||||
{contextHolder}
|
||||
|
||||
{/* Main Canvas */}
|
||||
<div className="flex-1 relative">
|
||||
<ReactFlow
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onConnect={onConnect}
|
||||
onNodeDragStop={onNodeDragStop}
|
||||
onEdgesDelete={onEdgesDelete}
|
||||
nodeTypes={nodeTypes}
|
||||
fitView
|
||||
minZoom={0.1}
|
||||
maxZoom={2}
|
||||
defaultEdgeOptions={{
|
||||
animated: false,
|
||||
style: { stroke: "#6b7280", strokeWidth: 2 },
|
||||
}}
|
||||
>
|
||||
<Background
|
||||
variant={BackgroundVariant.Dots}
|
||||
gap={20}
|
||||
size={1}
|
||||
color="#e0e0e0"
|
||||
/>
|
||||
|
||||
{showMiniMap && (
|
||||
<MiniMap
|
||||
nodeStrokeColor="#666"
|
||||
nodeColor="#fff"
|
||||
maskColor="rgba(0,0,0,0.1)"
|
||||
position="bottom-right"
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Top Panel - Workflow Info */}
|
||||
<Panel
|
||||
position="top-left"
|
||||
className="bg-primary rounded border border-secondary p-2 shadow-sm"
|
||||
>
|
||||
<div className="text-sm">
|
||||
<div className="font-medium text-primary">
|
||||
{workflow.name || "Workflow"}
|
||||
</div>
|
||||
<div className="text-secondary text-xs">
|
||||
Sequential execution with data flow between steps
|
||||
</div>
|
||||
<div className="text-xs text-secondary mt-1">
|
||||
{workflow.config.steps.length} steps,{" "}
|
||||
{workflow.config.edges?.length || 0} connections
|
||||
</div>
|
||||
</div>
|
||||
</Panel>
|
||||
</ReactFlow>
|
||||
|
||||
{/* Toolbar */}
|
||||
<Toolbar
|
||||
isDirty={isDirty}
|
||||
onSave={handleSaveWorkflow}
|
||||
onRun={handleRunWorkflow}
|
||||
onAutoLayout={handleLayoutNodes}
|
||||
onToggleMiniMap={() => setShowMiniMap(!showMiniMap)}
|
||||
onToggleGrid={() => setShowGrid(!showGrid)}
|
||||
showMiniMap={showMiniMap}
|
||||
showGrid={showGrid}
|
||||
disabled={workflow.config.steps.length === 0}
|
||||
edgeType={edgeType}
|
||||
onEdgeTypeChange={handleEdgeTypeChange}
|
||||
/>
|
||||
|
||||
{/* Empty State */}
|
||||
{workflow.config.steps.length === 0 && (
|
||||
<div className="absolute inset-0 flex items-center justify-center pointer-events-none">
|
||||
<div className="text-center text-secondary">
|
||||
<div className="text-lg font-medium mb-2">
|
||||
Start Building Your Workflow
|
||||
</div>
|
||||
<div className="text-sm">
|
||||
Click steps from the library to add them to your workflow
|
||||
</div>
|
||||
{isLibraryCompact && (
|
||||
<div className="text-xs text-accent mt-2">
|
||||
Expand the library to browse all available steps
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Agent Library Sidebar */}
|
||||
<div
|
||||
className={`${
|
||||
isLibraryCompact ? "w-12" : "w-80"
|
||||
} border-l border-secondary bg-tertiary transition-all duration-200`}
|
||||
>
|
||||
<StepLibrary
|
||||
onAddStep={handleAddStep}
|
||||
isCompact={isLibraryCompact}
|
||||
onToggleCompact={() => setIsLibraryCompact(!isLibraryCompact)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Agent Details Drawer */}
|
||||
<Drawer
|
||||
title={selectedStep?.name}
|
||||
placement="right"
|
||||
width={400}
|
||||
open={stepDrawerOpen}
|
||||
onClose={() => setStepDrawerOpen(false)}
|
||||
>
|
||||
{selectedStep && (
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label className="text-sm font-medium">Description</label>
|
||||
<div className="text-sm text-secondary mt-1">
|
||||
{selectedStep.description}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="text-sm font-medium">System Message</label>
|
||||
<div className="text-sm text-secondary mt-1 p-2 bg-secondary rounded">
|
||||
{selectedStep.system_message}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="text-sm font-medium">Model</label>
|
||||
<div className="text-sm text-secondary mt-1">
|
||||
{selectedStep.model}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="text-sm font-medium">Tools</label>
|
||||
<div className="flex flex-wrap gap-1 mt-1">
|
||||
{selectedStep.tools?.map((tool, index) => (
|
||||
<span
|
||||
key={index}
|
||||
className="px-2 py-1 text-xs bg-accent/10 text-accent rounded"
|
||||
>
|
||||
{tool}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="pt-4 border-t">
|
||||
<Button type="primary" className="w-full">
|
||||
Test Step
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</Drawer>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default WorkflowBuilder;
|
||||
@@ -0,0 +1,7 @@
|
||||
export { default as FoundryManager } from "./manager";
|
||||
export { default as FoundrySidebar } from "./sidebar";
|
||||
export { default as FoundryBuilder } from "./builder";
|
||||
export { default as NewWorkflowControls } from "./newworkflow";
|
||||
export * from "./types";
|
||||
export * from "./utils";
|
||||
export { foundryAPI } from "./api";
|
||||
@@ -0,0 +1,128 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { Input, Card, Tag, Tooltip, Spin } from "antd";
|
||||
import { ChevronsRight, ChevronsLeft, Search } from "lucide-react";
|
||||
import { Step } from "./types";
|
||||
import { workflowAPI } from "./api";
|
||||
|
||||
interface StepLibraryProps {
|
||||
onAddStep: (step: Step) => void;
|
||||
isCompact: boolean;
|
||||
onToggleCompact: () => void;
|
||||
}
|
||||
|
||||
export const StepLibrary: React.FC<StepLibraryProps> = ({
|
||||
onAddStep,
|
||||
isCompact,
|
||||
onToggleCompact,
|
||||
}) => {
|
||||
const [searchTerm, setSearchTerm] = useState("");
|
||||
const [steps, setSteps] = useState<Step[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchSteps = async () => {
|
||||
try {
|
||||
setIsLoading(true);
|
||||
const response = await workflowAPI.getSteps();
|
||||
if (response.success) {
|
||||
setSteps(response.data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch steps", error);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
fetchSteps();
|
||||
}, []);
|
||||
|
||||
const filteredSteps = steps.filter((step) =>
|
||||
step.name.toLowerCase().includes(searchTerm.toLowerCase())
|
||||
);
|
||||
|
||||
if (isCompact) {
|
||||
return (
|
||||
<div className="p-2 flex flex-col items-center">
|
||||
<Tooltip title="Expand Library" placement="left">
|
||||
<button
|
||||
onClick={onToggleCompact}
|
||||
className="p-2 text-secondary hover:text-primary rounded hover:bg-secondary/20 mb-3"
|
||||
>
|
||||
<ChevronsLeft size={16} />
|
||||
</button>
|
||||
</Tooltip>
|
||||
<div className="space-y-2">
|
||||
{steps?.slice(0, 5).map((step) => (
|
||||
<Tooltip key={step.id} title={`Add ${step.name}`} placement="left">
|
||||
<button
|
||||
onClick={() => onAddStep(step)}
|
||||
className="w-8 h-8 bg-secondary rounded text-primary hover:bg-accent transition-colors flex items-center justify-center text-xs font-medium"
|
||||
>
|
||||
{step.name.charAt(0).toUpperCase()}
|
||||
</button>
|
||||
</Tooltip>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="p-4 h-full flex flex-col">
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<h3 className="text-lg font-semibold text-primary">Step Library</h3>
|
||||
<Tooltip title="Collapse Library">
|
||||
<button
|
||||
onClick={onToggleCompact}
|
||||
className="p-2 text-secondary hover:text-primary rounded hover:bg-secondary/20"
|
||||
>
|
||||
<ChevronsRight size={16} />
|
||||
</button>
|
||||
</Tooltip>
|
||||
</div>
|
||||
<div className="relative mb-4">
|
||||
<Input
|
||||
placeholder="Search steps..."
|
||||
value={searchTerm}
|
||||
onChange={(e) => setSearchTerm(e.target.value)}
|
||||
prefix={<Search size={14} className="text-secondary mr-2" />}
|
||||
className="bg-tertiary border-secondary"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex-1 overflow-y-auto pr-2">
|
||||
{isLoading ? (
|
||||
<div className="flex justify-center mt-4">
|
||||
<Spin />
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-2">
|
||||
{filteredSteps.map((step) => (
|
||||
<Card
|
||||
key={step.id}
|
||||
hoverable
|
||||
className="bg-secondary border-secondary shadow-sm hover:shadow-md transition-shadow"
|
||||
onClick={() => onAddStep(step)}
|
||||
size="small"
|
||||
>
|
||||
<div className="font-semibold text-primary">{step.name}</div>
|
||||
<p className="text-xs text-secondary mt-1 mb-2">
|
||||
{step.description}
|
||||
</p>
|
||||
<div className="flex flex-wrap gap-1">
|
||||
<Tag color="blue" className="text-xs">
|
||||
{step.model || "Default Model"}
|
||||
</Tag>
|
||||
{step.tools && step.tools.length > 0 && (
|
||||
<Tag color="geekblue" className="text-xs">
|
||||
{step.tools.length} tools
|
||||
</Tag>
|
||||
)}
|
||||
</div>
|
||||
</Card>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,272 @@
|
||||
import React, { useCallback, useEffect, useState, useContext } from "react";
|
||||
import { message, Modal } from "antd";
|
||||
import { ChevronRight } from "lucide-react";
|
||||
import { appContext } from "../../../hooks/provider";
|
||||
import { workflowAPI } from "./api";
|
||||
import { WorkflowSidebar } from "./sidebar";
|
||||
import { Workflow } from "./types";
|
||||
import WorkflowBuilder from "./builder";
|
||||
|
||||
export const WorkflowManager: React.FC = () => {
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [workflows, setWorkflows] = useState<Workflow[]>([]);
|
||||
const [currentWorkflow, setCurrentWorkflow] = useState<Workflow | null>(null);
|
||||
const [isSidebarOpen, setIsSidebarOpen] = useState(() => {
|
||||
if (typeof window !== "undefined") {
|
||||
const stored = localStorage.getItem("workflowSidebar");
|
||||
return stored !== null ? JSON.parse(stored) : true;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
const { user } = useContext(appContext);
|
||||
const [messageApi, contextHolder] = message.useMessage();
|
||||
const [hasUnsavedChanges, setHasUnsavedChanges] = useState(false);
|
||||
|
||||
// Persist sidebar state
|
||||
useEffect(() => {
|
||||
if (typeof window !== "undefined") {
|
||||
localStorage.setItem("workflowSidebar", JSON.stringify(isSidebarOpen));
|
||||
}
|
||||
}, [isSidebarOpen]);
|
||||
|
||||
const fetchWorkflows = useCallback(async () => {
|
||||
try {
|
||||
setIsLoading(true);
|
||||
const response = await workflowAPI.getWorkflows();
|
||||
if (response.success && response.data) {
|
||||
setWorkflows(response.data);
|
||||
if (!currentWorkflow && response.data.length > 0) {
|
||||
setCurrentWorkflow(response.data[0]);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error fetching workflows:", error);
|
||||
messageApi.error("Failed to load workflows");
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [currentWorkflow, messageApi]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchWorkflows();
|
||||
}, [fetchWorkflows]);
|
||||
|
||||
// Handle URL params
|
||||
useEffect(() => {
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
const workflowId = params.get("workflowId");
|
||||
|
||||
if (workflowId && !currentWorkflow) {
|
||||
handleSelectWorkflow({ id: workflowId } as Workflow);
|
||||
}
|
||||
}, [currentWorkflow]);
|
||||
|
||||
const handleSelectWorkflow = async (selectedWorkflow: Workflow) => {
|
||||
if (!selectedWorkflow.id) return;
|
||||
|
||||
if (hasUnsavedChanges) {
|
||||
Modal.confirm({
|
||||
title: "Unsaved Changes",
|
||||
content: "You have unsaved changes. Do you want to discard them?",
|
||||
okText: "Discard",
|
||||
cancelText: "Go Back",
|
||||
onOk: () => {
|
||||
switchToWorkflow(selectedWorkflow.id);
|
||||
},
|
||||
});
|
||||
} else {
|
||||
await switchToWorkflow(selectedWorkflow.id);
|
||||
}
|
||||
};
|
||||
|
||||
const switchToWorkflow = async (workflowId: string) => {
|
||||
if (!workflowId) return;
|
||||
setIsLoading(true);
|
||||
try {
|
||||
const response = await workflowAPI.getWorkflow(workflowId);
|
||||
if (response.success && response.data) {
|
||||
setCurrentWorkflow(response.data);
|
||||
window.history.pushState({}, "", `?workflowId=${workflowId}`);
|
||||
setHasUnsavedChanges(false);
|
||||
} else {
|
||||
messageApi.error(response.message || "Failed to load workflow");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error loading workflow:", error);
|
||||
messageApi.error("Failed to load workflow");
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDeleteWorkflow = async (workflowId: string) => {
|
||||
try {
|
||||
const response = await workflowAPI.deleteWorkflow(workflowId);
|
||||
if (response.success) {
|
||||
setWorkflows(workflows.filter((w) => w.id !== workflowId));
|
||||
if (currentWorkflow?.id === workflowId) {
|
||||
setCurrentWorkflow(
|
||||
workflows.find((w) => w.id !== workflowId) || null
|
||||
);
|
||||
}
|
||||
messageApi.success("Workflow deleted");
|
||||
} else {
|
||||
messageApi.error(response.message || "Failed to delete workflow");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error deleting workflow:", error);
|
||||
messageApi.error("Error deleting workflow");
|
||||
}
|
||||
};
|
||||
|
||||
const handleCreateWorkflow = async () => {
|
||||
try {
|
||||
const name = "New Workflow";
|
||||
const response = await workflowAPI.createWorkflow({
|
||||
name,
|
||||
description: "A new workflow.",
|
||||
config: {
|
||||
id: `config-${Date.now()}`,
|
||||
name,
|
||||
description: "A new workflow.",
|
||||
steps: [],
|
||||
edges: [],
|
||||
},
|
||||
});
|
||||
|
||||
if (response.success && response.data) {
|
||||
const newWorkflow = response.data;
|
||||
setWorkflows([newWorkflow, ...workflows]);
|
||||
setCurrentWorkflow(newWorkflow);
|
||||
messageApi.success("Workflow created successfully");
|
||||
} else {
|
||||
messageApi.error(response.message || "Failed to create workflow");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error creating workflow:", error);
|
||||
messageApi.error("Error creating workflow");
|
||||
}
|
||||
};
|
||||
|
||||
const handleWorkflowChange = (workflowData: Partial<Workflow>) => {
|
||||
if (!currentWorkflow) return;
|
||||
|
||||
const updatedWorkflow = {
|
||||
...currentWorkflow,
|
||||
...workflowData,
|
||||
};
|
||||
|
||||
setCurrentWorkflow(updatedWorkflow);
|
||||
setHasUnsavedChanges(true);
|
||||
};
|
||||
|
||||
const handleSaveWorkflow = async (workflowData: Partial<Workflow>) => {
|
||||
if (!currentWorkflow?.id) return;
|
||||
|
||||
try {
|
||||
const response = await workflowAPI.updateWorkflow(currentWorkflow.id, {
|
||||
id: currentWorkflow.id,
|
||||
name: workflowData.name,
|
||||
description: workflowData.description,
|
||||
config: workflowData.config,
|
||||
});
|
||||
|
||||
if (response.success && response.data) {
|
||||
const savedWorkflow = response.data;
|
||||
setWorkflows(
|
||||
workflows.map((w) => (w.id === savedWorkflow.id ? savedWorkflow : w))
|
||||
);
|
||||
setCurrentWorkflow(savedWorkflow);
|
||||
setHasUnsavedChanges(false);
|
||||
messageApi.success("Workflow saved successfully");
|
||||
} else {
|
||||
messageApi.error(response.message || "Failed to save workflow");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error saving workflow:", error);
|
||||
messageApi.error("Error saving workflow");
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="relative flex h-full w-full">
|
||||
{contextHolder}
|
||||
|
||||
{/* Sidebar */}
|
||||
<div
|
||||
className={`absolute left-0 top-0 h-full transition-all duration-200 ease-in-out z-10 ${
|
||||
isSidebarOpen ? "w-64" : "w-12"
|
||||
}`}
|
||||
>
|
||||
<WorkflowSidebar
|
||||
isOpen={isSidebarOpen}
|
||||
workflows={workflows}
|
||||
currentWorkflow={currentWorkflow}
|
||||
onToggle={() => setIsSidebarOpen(!isSidebarOpen)}
|
||||
onSelectWorkflow={handleSelectWorkflow}
|
||||
onCreateWorkflow={handleCreateWorkflow}
|
||||
onDeleteWorkflow={handleDeleteWorkflow}
|
||||
isLoading={isLoading}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Main Content */}
|
||||
<div
|
||||
className={`flex-1 transition-all duration-200 ${
|
||||
isSidebarOpen ? "ml-64" : "ml-12"
|
||||
}`}
|
||||
>
|
||||
<div className="p-4 pt-2 h-full">
|
||||
{/* Breadcrumb */}
|
||||
<div className="flex items-center gap-2 mb-4 text-sm">
|
||||
<span className="text-primary font-medium">Workflows</span>
|
||||
{currentWorkflow && (
|
||||
<>
|
||||
<ChevronRight className="w-4 h-4 text-secondary" />
|
||||
<span className="text-secondary">
|
||||
{currentWorkflow.name}
|
||||
{!currentWorkflow.id && (
|
||||
<span className="text-xs text-orange-500"> (New)</span>
|
||||
)}
|
||||
</span>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Content Area */}
|
||||
{currentWorkflow ? (
|
||||
<div className="h-[calc(100vh-120px)]">
|
||||
<WorkflowBuilder
|
||||
workflow={currentWorkflow}
|
||||
onChange={handleWorkflowChange}
|
||||
onSave={handleSaveWorkflow}
|
||||
onDirtyStateChange={setHasUnsavedChanges}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex items-center justify-center h-[calc(100vh-120px)] text-secondary">
|
||||
<div className="text-center">
|
||||
<h3 className="text-lg font-medium mb-2">Welcome</h3>
|
||||
<p className="text-sm mb-4">
|
||||
Select a workflow from the sidebar or create a new one
|
||||
</p>
|
||||
<div className="flex gap-2 justify-center">
|
||||
<button
|
||||
onClick={() => handleCreateWorkflow()}
|
||||
className="px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 transition-colors"
|
||||
>
|
||||
Create Workflow
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default WorkflowManager;
|
||||
@@ -0,0 +1,38 @@
|
||||
import React from "react";
|
||||
import { Button } from "antd";
|
||||
import { Plus, GitBranch } from "lucide-react";
|
||||
|
||||
interface NewWorkflowControlsProps {
|
||||
isLoading: boolean;
|
||||
onCreateWorkflow: () => void;
|
||||
}
|
||||
|
||||
const NewWorkflowControls = ({
|
||||
isLoading,
|
||||
onCreateWorkflow,
|
||||
}: NewWorkflowControlsProps) => {
|
||||
const handleCreateWorkflow = async () => {
|
||||
await onCreateWorkflow();
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-2 w-full">
|
||||
<Button
|
||||
type="primary"
|
||||
className="w-full"
|
||||
onClick={handleCreateWorkflow}
|
||||
disabled={isLoading}
|
||||
icon={<Plus className="w-4 h-4" />}
|
||||
>
|
||||
New Workflow
|
||||
</Button>
|
||||
|
||||
<div className="text-xs text-secondary flex items-center justify-center gap-1">
|
||||
<GitBranch className="w-3 h-3" />
|
||||
<span>Graph-based workflow</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default NewWorkflowControls;
|
||||
@@ -0,0 +1,93 @@
|
||||
import React from "react";
|
||||
import { Handle, Position, NodeProps, Node } from "@xyflow/react";
|
||||
import { Bot, X } from "lucide-react";
|
||||
import { NodeData } from "./types";
|
||||
|
||||
type StepNodeType = Node<NodeData>;
|
||||
|
||||
export const StepNode: React.FC<NodeProps<StepNodeType>> = ({
|
||||
data,
|
||||
selected,
|
||||
id,
|
||||
}) => {
|
||||
const { step, onDelete } = data;
|
||||
|
||||
if (!step) {
|
||||
return (
|
||||
<div className="p-4 border border-red-500 bg-red-50 rounded">
|
||||
<div className="text-red-600">Error: Step data not found</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const handleDelete = (e: React.MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
onDelete?.(id);
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`
|
||||
group relative w-[220px] bg-primary rounded-lg border-2 shadow-sm transition-all
|
||||
${selected ? "border-accent shadow-lg" : "border-secondary"}
|
||||
cursor-pointer
|
||||
`}
|
||||
>
|
||||
<Handle
|
||||
type="target"
|
||||
position={Position.Left}
|
||||
className="!bg-accent !w-2 !h-5 !rounded-r-sm !-ml-1 !border-0 hover:!bg-accent/80 transition-colors"
|
||||
/>
|
||||
<Handle
|
||||
type="source"
|
||||
position={Position.Right}
|
||||
className="!bg-accent !w-2 !h-5 !rounded-l-sm !-mr-1 !border-0 hover:!bg-accent/80 transition-colors"
|
||||
/>
|
||||
|
||||
<div className="p-3">
|
||||
<div className="flex items-center gap-2 mb-2">
|
||||
<Bot className="w-4 h-4 text-accent flex-shrink-0" />
|
||||
<span
|
||||
className="font-medium text-sm truncate flex-1 text-primary"
|
||||
title={step.name}
|
||||
>
|
||||
{step.name}
|
||||
</span>
|
||||
<button
|
||||
onClick={handleDelete}
|
||||
className="opacity-0 group-hover:opacity-100 transition-opacity text-secondary hover:text-red-500"
|
||||
aria-label="Delete step"
|
||||
>
|
||||
<X size={14} />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div
|
||||
className="text-xs text-secondary mb-2 overflow-hidden"
|
||||
style={{
|
||||
display: "-webkit-box",
|
||||
WebkitLineClamp: 2,
|
||||
WebkitBoxOrient: "vertical",
|
||||
height: "2rem",
|
||||
lineHeight: "1rem",
|
||||
}}
|
||||
title={step.description}
|
||||
>
|
||||
{step.description}
|
||||
</div>
|
||||
|
||||
<div className="flex items-center gap-2 text-xs">
|
||||
<span className="px-1.5 py-0.5 bg-accent/10 text-accent rounded text-xs truncate flex-1">
|
||||
{step.model || "Not specified"}
|
||||
</span>
|
||||
|
||||
{step.tools && step.tools.length > 0 && (
|
||||
<span className="px-1.5 py-0.5 bg-secondary/50 text-secondary rounded text-xs">
|
||||
{step.tools.length} tools
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,193 @@
|
||||
import React from "react";
|
||||
import { Button, Tooltip } from "antd";
|
||||
import {
|
||||
Plus,
|
||||
Trash2,
|
||||
PanelLeftClose,
|
||||
PanelLeftOpen,
|
||||
GitBranch,
|
||||
History,
|
||||
InfoIcon,
|
||||
RefreshCcw,
|
||||
} from "lucide-react";
|
||||
import { Workflow } from "./types";
|
||||
import { getRelativeTimeString } from "../atoms";
|
||||
import { getWorkflowTypeColor } from "./utils";
|
||||
import NewWorkflowControls from "./newworkflow";
|
||||
|
||||
interface WorkflowSidebarProps {
|
||||
isOpen: boolean;
|
||||
workflows: Workflow[];
|
||||
currentWorkflow: Workflow | null;
|
||||
onToggle: () => void;
|
||||
onSelectWorkflow: (workflow: Workflow) => void;
|
||||
onCreateWorkflow: () => void;
|
||||
onDeleteWorkflow: (workflowId: string) => void;
|
||||
isLoading?: boolean;
|
||||
}
|
||||
|
||||
export const WorkflowSidebar: React.FC<WorkflowSidebarProps> = ({
|
||||
isOpen,
|
||||
workflows,
|
||||
currentWorkflow,
|
||||
onToggle,
|
||||
onSelectWorkflow,
|
||||
onCreateWorkflow,
|
||||
onDeleteWorkflow,
|
||||
isLoading = false,
|
||||
}) => {
|
||||
if (!isOpen) {
|
||||
return (
|
||||
<div className="h-full border-r border-secondary">
|
||||
<div className="p-2 -ml-2">
|
||||
<Tooltip
|
||||
title={
|
||||
<span>
|
||||
Workflows{" "}
|
||||
<span className="text-accent mx-1"> {workflows.length} </span>
|
||||
</span>
|
||||
}
|
||||
>
|
||||
<button
|
||||
onClick={onToggle}
|
||||
className="p-2 rounded-md hover:bg-secondary hover:text-accent text-secondary transition-colors focus:outline-none focus:ring-2 focus:ring-accent focus:ring-opacity-50"
|
||||
>
|
||||
<PanelLeftOpen strokeWidth={1.5} className="h-6 w-6" />
|
||||
</button>
|
||||
</Tooltip>
|
||||
</div>
|
||||
<div className="mt-4 px-2 -ml-1">
|
||||
<Tooltip title="Create new workflow">
|
||||
<Button
|
||||
type="text"
|
||||
className="w-full p-2 flex justify-center"
|
||||
onClick={() => onCreateWorkflow()}
|
||||
icon={<Plus className="w-4 h-4" />}
|
||||
/>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="h-full border-r border-secondary">
|
||||
<div className="flex items-center justify-between pt-0 p-4 pl-2 pr-2 border-b border-secondary">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-primary font-medium">Workflows</span>
|
||||
<span className="px-2 py-0.5 text-xs bg-accent/10 text-accent rounded">
|
||||
{workflows.length}
|
||||
</span>
|
||||
</div>
|
||||
<Tooltip title="Close Sidebar">
|
||||
<button
|
||||
onClick={onToggle}
|
||||
className="p-2 rounded-md hover:bg-secondary hover:text-accent text-secondary transition-colors focus:outline-none focus:ring-2 focus:ring-accent focus:ring-opacity-50"
|
||||
>
|
||||
<PanelLeftClose strokeWidth={1.5} className="h-6 w-6" />
|
||||
</button>
|
||||
</Tooltip>
|
||||
</div>
|
||||
|
||||
<div className="my-4 flex text-sm">
|
||||
<div className="mr-2 w-full pr-2">
|
||||
{isOpen && (
|
||||
<NewWorkflowControls
|
||||
isLoading={isLoading}
|
||||
onCreateWorkflow={onCreateWorkflow}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="py-2 flex text-sm text-secondary">
|
||||
<History className="w-4 h-4 inline-block mr-1.5" />
|
||||
<div className="inline-block -mt-0.5">
|
||||
Recents{" "}
|
||||
<span className="text-accent text-xs mx-1 mt-0.5">
|
||||
{" "}
|
||||
({workflows.length}){" "}
|
||||
</span>{" "}
|
||||
</div>
|
||||
|
||||
{isLoading && (
|
||||
<RefreshCcw className="w-4 h-4 inline-block ml-2 animate-spin" />
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* no workflows found */}
|
||||
{!isLoading && workflows.length === 0 && (
|
||||
<div className="p-2 mr-2 text-center text-secondary text-sm border border-dashed rounded">
|
||||
<InfoIcon className="w-4 h-4 inline-block mr-1.5 -mt-0.5" />
|
||||
No recent workflows found
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="overflow-y-auto scroll h-[calc(100%-181px)]">
|
||||
{workflows.map((workflow) => (
|
||||
<div
|
||||
key={workflow.id}
|
||||
className="relative group"
|
||||
onClick={() => onSelectWorkflow(workflow)}
|
||||
>
|
||||
<div
|
||||
className={`
|
||||
absolute top-1 left-0.5 z-10 h-[calc(100%-8px)] w-1 rounded
|
||||
${
|
||||
currentWorkflow?.id === workflow.id
|
||||
? "bg-accent"
|
||||
: "bg-transparent group-hover:bg-secondary"
|
||||
}
|
||||
`}
|
||||
/>
|
||||
<div
|
||||
className={`
|
||||
p-2 m-1 ml-2 rounded cursor-pointer transition-colors
|
||||
${
|
||||
currentWorkflow?.id === workflow.id
|
||||
? "bg-secondary"
|
||||
: "hover:bg-secondary"
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-2 flex-1 min-w-0">
|
||||
<div
|
||||
className={`p-1.5 rounded ${getWorkflowTypeColor(
|
||||
"workflow"
|
||||
)}`}
|
||||
>
|
||||
<GitBranch className="w-4 h-4" />
|
||||
</div>
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="font-medium text-sm truncate text-primary">
|
||||
{workflow.name}
|
||||
</div>
|
||||
<div className="text-xs text-secondary truncate">
|
||||
{getRelativeTimeString(new Date(workflow.updated_at))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="opacity-0 group-hover:opacity-100 transition-opacity">
|
||||
<Tooltip title="Delete workflow">
|
||||
<button
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onDeleteWorkflow(workflow.id);
|
||||
}}
|
||||
className="p-1 rounded hover:bg-red-500/10 text-secondary hover:text-red-500"
|
||||
>
|
||||
<Trash2 className="w-3.5 h-3.5" />
|
||||
</button>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default WorkflowSidebar;
|
||||
@@ -0,0 +1,148 @@
|
||||
import React, { useState } from "react";
|
||||
import {
|
||||
Save,
|
||||
Play,
|
||||
Layout,
|
||||
Map,
|
||||
Grid,
|
||||
MoreVertical,
|
||||
Settings,
|
||||
} from "lucide-react";
|
||||
import { Button, Tooltip, Segmented, Popover } from "antd";
|
||||
|
||||
interface ToolbarProps {
|
||||
isDirty: boolean;
|
||||
onSave: () => void;
|
||||
onRun: () => void;
|
||||
onAutoLayout: () => void;
|
||||
onToggleMiniMap: () => void;
|
||||
onToggleGrid: () => void;
|
||||
showMiniMap: boolean;
|
||||
showGrid: boolean;
|
||||
disabled: boolean;
|
||||
edgeType: string;
|
||||
onEdgeTypeChange: (type: string) => void;
|
||||
}
|
||||
|
||||
export const Toolbar: React.FC<ToolbarProps> = ({
|
||||
isDirty,
|
||||
onSave,
|
||||
onRun,
|
||||
onAutoLayout,
|
||||
onToggleMiniMap,
|
||||
onToggleGrid,
|
||||
showMiniMap,
|
||||
showGrid,
|
||||
disabled,
|
||||
edgeType,
|
||||
onEdgeTypeChange,
|
||||
}) => {
|
||||
const [showAdvanced, setShowAdvanced] = useState(false);
|
||||
|
||||
// Advanced settings content
|
||||
const advancedContent = (
|
||||
<div className="flex flex-col gap-2 w-64">
|
||||
<div className="font-semibold text-sm mb-2">View Options</div>
|
||||
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm">Minimap</span>
|
||||
<Button
|
||||
type={showMiniMap ? "primary" : "default"}
|
||||
size="small"
|
||||
icon={<Map size={14} />}
|
||||
onClick={onToggleMiniMap}
|
||||
>
|
||||
{showMiniMap ? "Hide" : "Show"}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm">Grid</span>
|
||||
<Button
|
||||
type={showGrid ? "primary" : "default"}
|
||||
size="small"
|
||||
icon={<Grid size={14} />}
|
||||
onClick={onToggleGrid}
|
||||
>
|
||||
{showGrid ? "Hide" : "Show"}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="border-t pt-2 mt-2">
|
||||
<div className="font-semibold text-sm mb-2">Edge Style</div>
|
||||
<Segmented
|
||||
options={[
|
||||
{ label: "Smooth", value: "smoothstep" },
|
||||
{ label: "Straight", value: "straight" },
|
||||
{ label: "Step", value: "step" },
|
||||
]}
|
||||
value={edgeType}
|
||||
onChange={(value) => onEdgeTypeChange(value as string)}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="border-t pt-2 mt-2">
|
||||
<div className="font-semibold text-sm mb-2">Layout</div>
|
||||
<Button
|
||||
icon={<Layout size={14} />}
|
||||
onClick={onAutoLayout}
|
||||
disabled={disabled}
|
||||
size="small"
|
||||
block
|
||||
>
|
||||
Auto-arrange Nodes
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="absolute top-4 right-4 z-10">
|
||||
{/* Main Toolbar - Vertical Layout */}
|
||||
<div className="flex flex-col gap-2">
|
||||
{/* Primary Actions */}
|
||||
<div className="flex flex-col bg-primary rounded-md border border-secondary shadow-sm overflow-hidden">
|
||||
<Tooltip title="Run Workflow" placement="left">
|
||||
<Button
|
||||
type="text"
|
||||
icon={<Play size={18} />}
|
||||
onClick={onRun}
|
||||
disabled={disabled}
|
||||
className="h-10 w-10 flex items-center justify-center"
|
||||
/>
|
||||
</Tooltip>
|
||||
|
||||
<Tooltip title="Save Workflow" placement="left">
|
||||
<Button
|
||||
type={isDirty ? "primary" : "text"}
|
||||
icon={<Save size={18} />}
|
||||
onClick={onSave}
|
||||
disabled={!isDirty}
|
||||
className="h-10 w-10 flex items-center justify-center"
|
||||
/>
|
||||
</Tooltip>
|
||||
|
||||
<Popover
|
||||
content={advancedContent}
|
||||
title="Workflow Settings"
|
||||
trigger="click"
|
||||
placement="leftTop"
|
||||
open={showAdvanced}
|
||||
onOpenChange={setShowAdvanced}
|
||||
>
|
||||
<Tooltip title="More Options" placement="left">
|
||||
<Button
|
||||
type={showAdvanced ? "primary" : "text"}
|
||||
icon={<Settings size={18} />}
|
||||
className="h-10 w-10 flex items-center justify-center"
|
||||
/>
|
||||
</Tooltip>
|
||||
</Popover>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default Toolbar;
|
||||
@@ -0,0 +1,88 @@
|
||||
// A Step in a workflow. For now, it retains agent-like properties for the UI.
|
||||
export interface Step {
|
||||
id: string;
|
||||
name: string;
|
||||
description: string;
|
||||
type: "agent_step"; // A more generic type
|
||||
system_message?: string;
|
||||
tools?: string[];
|
||||
model?: string;
|
||||
metadata?: Record<string, any>;
|
||||
}
|
||||
|
||||
// An Edge in the workflow graph, connecting two steps.
|
||||
export interface WorkflowEdge {
|
||||
id: string;
|
||||
from_step: string; // Source step id
|
||||
to_step: string; // Target step id
|
||||
condition?: string; // Optional condition for the edge
|
||||
}
|
||||
|
||||
// The configuration for a workflow, mirroring the backend spec.
|
||||
export interface WorkflowConfig {
|
||||
id: string;
|
||||
name: string;
|
||||
description: string;
|
||||
steps: Step[];
|
||||
edges: WorkflowEdge[];
|
||||
start_step_id?: string;
|
||||
end_step_ids?: string[];
|
||||
initial_state?: Record<string, any>;
|
||||
metadata?: Record<string, any>;
|
||||
}
|
||||
|
||||
// The top-level workflow object, containing the config and other metadata.
|
||||
export interface Workflow {
|
||||
id: string;
|
||||
name: string;
|
||||
description: string;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
config: WorkflowConfig;
|
||||
user_id?: string;
|
||||
}
|
||||
|
||||
// UI State Types
|
||||
export interface WorkflowState {
|
||||
currentWorkflow: Workflow | null;
|
||||
workflows: Workflow[];
|
||||
selectedStep: Step | null;
|
||||
isEditing: boolean;
|
||||
isLoading: boolean;
|
||||
error: string | null;
|
||||
}
|
||||
|
||||
// Data structure for nodes in React Flow
|
||||
export interface NodeData extends Record<string, unknown> {
|
||||
step: Step;
|
||||
onDelete?: (id: string) => void;
|
||||
// any other node-specific data
|
||||
}
|
||||
|
||||
// A library of reusable steps
|
||||
export interface StepLibrary {
|
||||
name: string;
|
||||
description: string;
|
||||
steps: Step[];
|
||||
}
|
||||
|
||||
// API Response Types
|
||||
export interface ApiResponse<T> {
|
||||
success: boolean;
|
||||
data: T;
|
||||
message?: string; // Optional error message
|
||||
}
|
||||
|
||||
// API Request Payloads
|
||||
export interface CreateWorkflowRequest {
|
||||
name: string;
|
||||
description: string;
|
||||
config: WorkflowConfig;
|
||||
}
|
||||
|
||||
export interface UpdateWorkflowRequest {
|
||||
id: string;
|
||||
name?: string;
|
||||
description?: string;
|
||||
config?: Partial<WorkflowConfig>;
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
import {
|
||||
FoundryWorkflowConfig,
|
||||
WorkflowConfig,
|
||||
FoundryAgentConfig,
|
||||
Position,
|
||||
FoundryNodeData,
|
||||
} from "./types";
|
||||
import { Node, Edge } from "@xyflow/react";
|
||||
|
||||
// Type guard
|
||||
export const isWorkflow = (
|
||||
config: FoundryWorkflowConfig
|
||||
): config is WorkflowConfig => {
|
||||
return config.type === "workflow";
|
||||
};
|
||||
|
||||
// Workflow utilities
|
||||
export const createEmptyWorkflow = (
|
||||
name: string,
|
||||
description: string
|
||||
): WorkflowConfig => ({
|
||||
id: `workflow-${Date.now()}`,
|
||||
name,
|
||||
description,
|
||||
type: "workflow",
|
||||
agents: [],
|
||||
edges: [],
|
||||
termination_conditions: ["max_messages:20"],
|
||||
});
|
||||
|
||||
export const addAgentToWorkflow = (
|
||||
config: WorkflowConfig,
|
||||
agent: FoundryAgentConfig
|
||||
): WorkflowConfig => {
|
||||
// Ensure no duplicate agents are added
|
||||
if (config.agents.some((a) => a.id === agent.id)) {
|
||||
return config;
|
||||
}
|
||||
return {
|
||||
...config,
|
||||
agents: [...config.agents, agent],
|
||||
};
|
||||
};
|
||||
|
||||
// Local storage keys
|
||||
const getPositionKey = (workflowId: string) =>
|
||||
`workflow-${workflowId}-positions`;
|
||||
|
||||
// Save node positions to local storage
|
||||
export const saveNodePosition = (
|
||||
workflowId: string,
|
||||
nodeId: string,
|
||||
position: { x: number; y: number }
|
||||
) => {
|
||||
const key = getPositionKey(workflowId);
|
||||
const positions = JSON.parse(localStorage.getItem(key) || "{}");
|
||||
positions[nodeId] = position;
|
||||
localStorage.setItem(key, JSON.stringify(positions));
|
||||
};
|
||||
|
||||
// Load node positions from local storage
|
||||
export const loadNodePositions = (workflowId: string) => {
|
||||
const key = getPositionKey(workflowId);
|
||||
return JSON.parse(localStorage.getItem(key) || "{}");
|
||||
};
|
||||
|
||||
// Remove a node's position from local storage
|
||||
export const removeNodePosition = (workflowId: string, nodeId: string) => {
|
||||
const key = getPositionKey(workflowId);
|
||||
const positions = JSON.parse(localStorage.getItem(key) || "{}");
|
||||
delete positions[nodeId];
|
||||
localStorage.setItem(key, JSON.stringify(positions));
|
||||
};
|
||||
|
||||
// Calculate a default position for a new node
|
||||
export const calculateNodePosition = (index: number, totalNodes: number) => {
|
||||
const x = 250 * (index % 4);
|
||||
const y = 150 * Math.floor(index / 4);
|
||||
return { x, y };
|
||||
};
|
||||
|
||||
// Convert workflow config to React Flow nodes
|
||||
export const convertToReactFlowNodes = (
|
||||
config: WorkflowConfig,
|
||||
workflowId: string,
|
||||
onDelete: (id: string) => void
|
||||
): Node<NodeData>[] => {
|
||||
const positions = loadNodePositions(workflowId);
|
||||
return config.steps.map((step, index): Node<NodeData> => {
|
||||
const position =
|
||||
positions[step.id] || calculateNodePosition(index, config.steps.length);
|
||||
return {
|
||||
id: step.id,
|
||||
type: "step",
|
||||
position,
|
||||
data: { step, onDelete },
|
||||
};
|
||||
});
|
||||
};
|
||||
|
||||
// Convert workflow config to React Flow edges
|
||||
export const convertToReactFlowEdges = (
|
||||
config: WorkflowConfig,
|
||||
edgeType: string
|
||||
): Edge[] => {
|
||||
return config.edges.map((edge) => ({
|
||||
id: edge.id,
|
||||
source: edge.from_step,
|
||||
target: edge.to_step,
|
||||
type: edgeType,
|
||||
}));
|
||||
};
|
||||
|
||||
// UI layout and color utilities
|
||||
export const getWorkflowTypeColor = (type: "workflow") => {
|
||||
return "bg-blue-500/10 text-blue-500";
|
||||
};
|
||||
|
||||
// Add a new step to the workflow config
|
||||
export const addStepToWorkflow = (
|
||||
config: WorkflowConfig,
|
||||
step: Step
|
||||
): WorkflowConfig => {
|
||||
// Avoid adding duplicate steps
|
||||
if (config.steps.some((s) => s.id === step.id)) {
|
||||
return config;
|
||||
}
|
||||
return {
|
||||
...config,
|
||||
steps: [...config.steps, step],
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,30 @@
|
||||
import * as React from "react";
|
||||
import Layout from "../components/layout";
|
||||
import { graphql } from "gatsby";
|
||||
import DeployManager from "../components/views/deploy/manager";
|
||||
import LabsManager from "../components/views/labs/manager";
|
||||
import { FoundryManager } from "../components/views/workflows";
|
||||
|
||||
// markup
|
||||
const WorkflowPage = ({ data }: any) => {
|
||||
return (
|
||||
<Layout meta={data.site.siteMetadata} title="Home" link={"/labs"}>
|
||||
<main style={{ height: "100%" }} className=" h-full ">
|
||||
<FoundryManager />
|
||||
</main>
|
||||
</Layout>
|
||||
);
|
||||
};
|
||||
|
||||
export const query = graphql`
|
||||
query HomePageQuery {
|
||||
site {
|
||||
siteMetadata {
|
||||
description
|
||||
title
|
||||
}
|
||||
}
|
||||
}
|
||||
`;
|
||||
|
||||
export default WorkflowPage;
|
||||
@@ -19,7 +19,7 @@ classifiers = [
|
||||
|
||||
|
||||
dependencies = [
|
||||
"pydantic",
|
||||
"pydantic>=2.0.0",
|
||||
"pydantic-settings",
|
||||
"fastapi[standard]",
|
||||
"typer",
|
||||
|
||||
Reference in New Issue
Block a user