Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-01-21 04:57:58 -05:00
feat(classic): add sub-agent architecture and LATS/multi-agent debate strategies
Add comprehensive sub-agent spawning infrastructure that enables prompt strategies to coordinate multiple agents for advanced reasoning patterns.

New files:
- forge/agent/execution_context.py: ExecutionContext, ResourceBudget, SubAgentHandle, and AgentFactory protocol for sub-agent lifecycle
- agent_factory/default_factory.py: DefaultAgentFactory implementation
- prompt_strategies/lats.py: Language Agent Tree Search using MCTS with sub-agents for action expansion and evaluation
- prompt_strategies/multi_agent_debate.py: Multi-agent debate with proposal, critique, and consensus phases

Key changes:
- BaseMultiStepPromptStrategy gains spawn_sub_agent(), run_sub_agent(), spawn_and_run(), and run_parallel() methods
- Agent class accepts an optional ExecutionContext and injects it into strategies
- Sub-agents enabled by default (enable_sub_agents=True)
- Resource limits: max_depth=5, max_sub_agents=25, max_cycles=25

All 7 strategies now available in benchmark: one_shot, rewoo, plan_execute, reflexion, tree_of_thoughts, lats, multi_agent_debate

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
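For orientation, here is a minimal sketch of how a strategy might drive the new coordination methods. The method names come from the commit message above; the exact signatures and return shapes are assumptions, since the strategy-side code is not part of this excerpt.

async def debate(strategy, question: str) -> list:
    """Fan out one sub-agent per debater position and collect results.

    `strategy` is assumed to be a BaseMultiStepPromptStrategy exposing the
    run_parallel() method this commit adds; the signature sketched here
    (a list of task strings in, completed SubAgentHandles out) is an
    assumption based on the commit message, not shown in this diff.
    """
    handles = await strategy.run_parallel([
        f"Argue FOR: {question}",
        f"Argue AGAINST: {question}",
    ])
    # SubAgentHandle fields (result, error) are defined in execution_context.py below.
    return [h.result for h in handles if h.error is None]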
@@ -7,7 +7,13 @@ from pydantic import BaseModel, Field

 # Type aliases
 StrategyName = Literal[
-    "one_shot", "rewoo", "plan_execute", "reflexion", "tree_of_thoughts"
+    "one_shot",
+    "rewoo",
+    "plan_execute",
+    "reflexion",
+    "tree_of_thoughts",
+    "lats",
+    "multi_agent_debate",
 ]
 ReasoningEffort = Literal["low", "medium", "high"]

@@ -17,6 +23,8 @@ STRATEGIES: list[StrategyName] = [
     "plan_execute",
     "reflexion",
     "tree_of_thoughts",
+    "lats",
+    "multi_agent_debate",
 ]
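As a usage note: because StrategyName is a Literal, the accepted names can be recovered at runtime with typing.get_args, which is handy for validating CLI or config input. A small self-contained sketch (the Literal is restated so the snippet runs on its own):

from typing import Literal, get_args

StrategyName = Literal[
    "one_shot", "rewoo", "plan_execute", "reflexion",
    "tree_of_thoughts", "lats", "multi_agent_debate",
]

def parse_strategy(name: str) -> str:
    # get_args() returns the tuple of allowed literal values,
    # so this check stays in sync with the type alias automatically.
    if name not in get_args(StrategyName):
        raise ValueError(f"unknown strategy {name!r}")
    return name

assert parse_strategy("lats") == "lats"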
classic/forge/forge/agent/execution_context.py (Normal file, 339 lines)
@@ -0,0 +1,339 @@
"""Execution context for sub-agent support.

This module provides the infrastructure for strategies to spawn and coordinate
sub-agents. The ExecutionContext is passed down the agent hierarchy and provides
access to shared resources while enforcing resource budgets.

Based on research from:
- Google ADK Multi-Agent Patterns
- Anthropic Multi-Agent Research System
- LATS (Language Agent Tree Search)
"""

from __future__ import annotations

import asyncio
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, Protocol

if TYPE_CHECKING:
    from forge.config.ai_directives import AIDirectives
    from forge.config.ai_profile import AIProfile
    from forge.file_storage.base import FileStorage
    from forge.llm.providers import MultiProvider


class SubAgentStatus(str, Enum):
    """Status of a sub-agent."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


@dataclass
class ResourceBudget:
    """Resource limits for sub-agent execution.

    Based on design decisions from research:
    - Permissive defaults for flexibility
    - Inherited deny rules, explicit allow rules
    """

    max_depth: int = 5
    """Maximum nesting depth for sub-agents."""

    max_sub_agents: int = 25
    """Maximum number of sub-agents that can be spawned."""

    max_cycles_per_agent: int = 50
    """Maximum execution cycles per agent."""

    max_tokens_total: int = 0
    """Maximum total tokens (0 = unlimited)."""

    inherited_deny_rules: list[str] = field(default_factory=list)
    """Permission deny rules inherited from parent (always enforced)."""

    explicit_allow_rules: list[str] = field(default_factory=list)
    """Permission allow rules explicitly granted to this context."""

    def create_child_budget(self) -> "ResourceBudget":
        """Create a budget for a child agent with reduced limits."""
        return ResourceBudget(
            max_depth=self.max_depth - 1,
            max_sub_agents=self.max_sub_agents,
            max_cycles_per_agent=self.max_cycles_per_agent,
            max_tokens_total=self.max_tokens_total,
            inherited_deny_rules=self.inherited_deny_rules.copy(),
            explicit_allow_rules=[],  # Child must get explicit permissions
        )


@dataclass
class SubAgentHandle:
    """Handle to a spawned sub-agent.

    Provides methods to interact with the sub-agent without exposing
    the full agent internals. This maintains proper encapsulation.
    """

    agent_id: str
    """Unique identifier for this sub-agent."""

    task: str
    """The task assigned to this sub-agent."""

    status: SubAgentStatus = SubAgentStatus.PENDING
    """Current status of the sub-agent."""

    result: Optional[Any] = None
    """Result from the sub-agent (if completed)."""

    error: Optional[str] = None
    """Error message (if failed)."""

    summary: str = ""
    """Brief summary of what the sub-agent accomplished."""

    # Internal fields (not part of public API)
    _agent: Optional[Any] = field(default=None, repr=False)
    _task: Optional[asyncio.Task[Any]] = field(default=None, repr=False)

    def is_running(self) -> bool:
        """Check if the sub-agent is currently running."""
        return self.status == SubAgentStatus.RUNNING

    def is_done(self) -> bool:
        """Check if the sub-agent has finished (success or failure)."""
        return self.status in (
            SubAgentStatus.COMPLETED,
            SubAgentStatus.FAILED,
            SubAgentStatus.CANCELLED,
        )


class AgentFactory(Protocol):
    """Protocol for agent factory implementations.

    This allows strategies to spawn sub-agents without knowing the
    concrete agent implementation details.
    """

    def create_agent(
        self,
        agent_id: str,
        task: str,
        context: "ExecutionContext",
        ai_profile: Optional["AIProfile"] = None,
        directives: Optional["AIDirectives"] = None,
        strategy: Optional[str] = None,
    ) -> Any:
        """Create a new agent instance.

        Args:
            agent_id: Unique identifier for the agent.
            task: The task the agent should accomplish.
            context: Execution context with shared resources.
            ai_profile: Optional AI profile override.
            directives: Optional directives override.
            strategy: Optional strategy name override.

        Returns:
            A new agent instance.
        """
        ...


@dataclass
class ExecutionContext:
    """Context passed down the agent hierarchy.

    The ExecutionContext provides sub-agents with access to shared resources
    (LLM provider, file storage, agent factory) while enforcing resource
    budgets and isolation.

    Key design decisions (based on research):
    1. File storage: Sub-agents can READ parent workspace but only WRITE
       to their own subdirectory (.sub_agents/{agent_id}/).
    2. Permissions: Inherit deny rules, explicit allow rules per sub-agent.
    3. Context isolation: Sub-agents get minimal context (just task description).
    4. History visibility: Parent sees sub-agent results only, not full history.
    """

    llm_provider: "MultiProvider"
    """Shared LLM provider for all agents in the hierarchy."""

    file_storage: "FileStorage"
    """File storage (may have write restrictions for sub-agents)."""

    agent_factory: Optional[AgentFactory] = None
    """Factory for creating sub-agents."""

    parent_agent_id: Optional[str] = None
    """ID of the parent agent (None for root agent)."""

    depth: int = 0
    """Current depth in the agent hierarchy (0 = root)."""

    budget: ResourceBudget = field(default_factory=ResourceBudget)
    """Resource budget for this context."""

    sub_agents: dict[str, SubAgentHandle] = field(default_factory=dict)
    """Active sub-agents spawned by the owning agent."""

    _cancelled: bool = field(default=False, repr=False)
    """Whether this context has been cancelled."""

    _app_config: Optional[Any] = field(default=None, repr=False)
    """Application config (for agent creation)."""

    @property
    def is_root(self) -> bool:
        """Check if this is a root (top-level) agent context."""
        return self.parent_agent_id is None

    @property
    def cancelled(self) -> bool:
        """Check if this context has been cancelled."""
        return self._cancelled

    def can_spawn_sub_agent(self) -> bool:
        """Check if spawning a sub-agent is allowed.

        Returns:
            True if spawning is allowed, False otherwise.
        """
        if self._cancelled:
            return False
        if self.budget.max_depth <= 0:
            return False
        if len(self.sub_agents) >= self.budget.max_sub_agents:
            return False
        if self.agent_factory is None:
            return False
        return True

    def create_child_context(self, child_agent_id: str) -> "ExecutionContext":
        """Create a context for a child agent with appropriate restrictions.

        The child context has:
        - Same LLM provider (shared)
        - Write-restricted file storage (writes to .sub_agents/{child_agent_id}/)
        - Reduced budget (depth - 1)
        - Same agent factory

        Args:
            child_agent_id: ID of the child agent.

        Returns:
            A new ExecutionContext for the child.
        """
        # Create write-restricted file storage for child
        child_storage = self._create_child_storage(child_agent_id)

        return ExecutionContext(
            llm_provider=self.llm_provider,
            file_storage=child_storage,
            agent_factory=self.agent_factory,
            parent_agent_id=child_agent_id,
            depth=self.depth + 1,
            budget=self.budget.create_child_budget(),
            _app_config=self._app_config,
        )

    def _create_child_storage(self, child_agent_id: str) -> "FileStorage":
        """Create a write-restricted file storage for a child agent.

        The child can:
        - READ from the entire parent workspace
        - WRITE only to .sub_agents/{child_agent_id}/

        Args:
            child_agent_id: ID of the child agent.

        Returns:
            A FileStorage with write restrictions.
        """
        # Use clone_with_subroot for the write directory
        # The child's "root" for writing is the sub_agents directory
        sub_agent_path = f".sub_agents/{child_agent_id}"

        # For now, we create a subroot storage for the child
        # This restricts ALL access to the subroot (both read and write)
        # A more sophisticated implementation would allow read from parent
        # but that requires extending FileStorage
        return self.file_storage.clone_with_subroot(sub_agent_path)

    def register_sub_agent(self, handle: SubAgentHandle) -> None:
        """Register a sub-agent handle.

        Args:
            handle: The sub-agent handle to register.
        """
        self.sub_agents[handle.agent_id] = handle

    def get_sub_agent(self, agent_id: str) -> Optional[SubAgentHandle]:
        """Get a sub-agent handle by ID.

        Args:
            agent_id: The sub-agent ID.

        Returns:
            The sub-agent handle, or None if not found.
        """
        return self.sub_agents.get(agent_id)

    def cancel(self) -> None:
        """Cancel this context and all sub-agents.

        This sets the cancelled flag and attempts to cancel any running
        sub-agent tasks.
        """
        self._cancelled = True
        for handle in self.sub_agents.values():
            if handle._task and not handle._task.done():
                handle._task.cancel()
                handle.status = SubAgentStatus.CANCELLED

    async def wait_for_sub_agents(
        self,
        timeout: Optional[float] = None,
    ) -> dict[str, SubAgentHandle]:
        """Wait for all running sub-agents to complete.

        Args:
            timeout: Maximum time to wait (in seconds).

        Returns:
            Dictionary of all sub-agent handles.
        """
        tasks = [
            handle._task
            for handle in self.sub_agents.values()
            if handle._task and not handle._task.done()
        ]

        if tasks:
            await asyncio.wait(tasks, timeout=timeout)

        return self.sub_agents


def generate_sub_agent_id(parent_id: Optional[str] = None) -> str:
    """Generate a unique ID for a sub-agent.

    Args:
        parent_id: Optional parent agent ID for hierarchical naming.

    Returns:
        A unique sub-agent ID.
    """
    short_uuid = str(uuid.uuid4())[:8]
    if parent_id:
        return f"{parent_id}-sub-{short_uuid}"
    return f"sub-{short_uuid}"
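To make the budget mechanics concrete, the following sketch walks a context down the hierarchy until spawning is refused. It uses only the classes above; the stub factory and storage are stand-ins (dataclasses do not validate field types at runtime), and real code would pass a MultiProvider and FileStorage.

# Assumes the module above is importable as forge.agent.execution_context.
from forge.agent.execution_context import (
    ExecutionContext,
    ResourceBudget,
    generate_sub_agent_id,
)


class _StubFactory:
    """Toy AgentFactory stand-in; a real one builds forge Agents."""

    def create_agent(self, agent_id, task, context, **kwargs):
        return {"id": agent_id, "task": task, "depth": context.depth}


class _StubStorage:
    """Toy FileStorage stand-in exposing only clone_with_subroot()."""

    def clone_with_subroot(self, subroot):
        return _StubStorage()


ctx = ExecutionContext(
    llm_provider=None,  # stand-in; real code passes a MultiProvider
    file_storage=_StubStorage(),
    agent_factory=_StubFactory(),
    budget=ResourceBudget(max_depth=2),
)

# Each child budget has max_depth reduced by one, so spawning stops
# after exactly two nested levels here.
spawned = 0
while ctx.can_spawn_sub_agent():
    child_id = generate_sub_agent_id(ctx.parent_agent_id)
    ctx = ctx.create_child_context(child_id)  # write-restricted storage, depth + 1
    spawned += 1

print(spawned)  # -> 2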
@@ -12,6 +12,7 @@ Features:

 import json
 import logging
+import uuid
 from typing import TYPE_CHECKING, Iterator, Literal, Optional

 from pydantic import BaseModel, ConfigDict, Field
@@ -56,9 +57,15 @@ Respond with ONLY a JSON object (no markdown, no explanation):
 "summary": "Brief explanation"}}"""


+def _generate_todo_id() -> str:
+    """Generate a short unique ID for todo items."""
+    return uuid.uuid4().hex[:8]
+
+
 class TodoItem(BaseModel):
     """A single todo item with optional nested sub-items."""

+    id: str = Field(default_factory=_generate_todo_id, description="Unique identifier")
     content: str = Field(..., description="Imperative form: 'Fix the bug'")
     status: TodoStatus = Field(default="pending", description="Task status")
     active_form: str = Field(
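The default_factory pattern above gives every item a stable 8-hex-character ID at creation time, which the incremental commands later in this diff use as their addressing scheme. A self-contained illustration of the same pattern (restated so it runs without forge):

import uuid

from pydantic import BaseModel, Field


def _generate_todo_id() -> str:
    return uuid.uuid4().hex[:8]


class Item(BaseModel):
    id: str = Field(default_factory=_generate_todo_id)
    content: str


a, b = Item(content="Fix bug"), Item(content="Write tests")
assert a.id != b.id and len(a.id) == 8  # fresh 8-char hex ID per item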
@@ -136,7 +143,9 @@ class TodoComponent(
         yield "A todo list to track and manage multi-step tasks. Use frequently!"

     def get_best_practices(self) -> Iterator[str]:
-        yield "Use todo_write when working on multi-step tasks to track progress"
+        yield "Use todo_bulk_add for initial planning, then incremental ops for updates"
+        yield "Use todo_set_status to mark tasks in_progress or completed"
+        yield "Use todo_add to add a single new task to an existing list"
         yield "Mark todos as in_progress before starting work on them"
         yield "Mark todos as completed immediately after finishing, not in batches"
         yield "Only have ONE todo as in_progress at a time"
@@ -237,8 +246,12 @@ class TodoComponent(
                 if parsed:
                     sub_items.append(parsed)

+        # Use provided ID or generate a new one
+        item_id = item.get("id") or _generate_todo_id()
+
         return (
             TodoItem(
+                id=item_id,
                 content=item["content"],
                 status=item["status"],
                 active_form=item["active_form"],
@@ -252,6 +265,7 @@ class TodoComponent(
         Recursively serialize a TodoItem to a dict including sub_items.
         """
         result: dict[str, str | list] = {
+            "id": item.id,
             "content": item.content,
             "status": item.status,
             "active_form": item.active_form,
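Since the serializer now emits the generated ID alongside the other fields, a serialized item (the shape the todo commands return) looks like this; the ID value is hypothetical:

serialized = {
    "id": "a1b2c3d4",  # hypothetical 8-hex-char ID from _generate_todo_id()
    "content": "Fix the bug",
    "status": "pending",
    "active_form": "Fixing the bug",
    "sub_items": [],  # nested items recurse through the same serializer
}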
@@ -262,122 +276,37 @@ class TodoComponent(
         ]
         return result

+    def _find_by_id(self, todo_id: str) -> Optional[TodoItem]:
+        """Find a todo item by its ID (top-level only)."""
+        for item in self._todos.items:
+            if item.id == todo_id:
+                return item
+        return None
+
+    def _find_index_by_id(self, todo_id: str) -> int:
+        """Find the index of a todo item by its ID. Returns -1 if not found."""
+        for i, item in enumerate(self._todos.items):
+            if item.id == todo_id:
+                return i
+        return -1
+
     # -------------------------------------------------------------------------
     # CommandProvider Implementation
     # -------------------------------------------------------------------------

     def get_commands(self) -> Iterator[Command]:
         yield self.todo_write
+        # Incremental operations (token efficient)
+        yield self.todo_add
+        yield self.todo_set_status
+        yield self.todo_update
+        yield self.todo_delete
+        yield self.todo_bulk_add
+        yield self.todo_reorder
+        # Core operations
         yield self.todo_read
         yield self.todo_clear
         yield self.todo_decompose

     @command(
         names=["todo_write"],
         parameters={
             "todos": JSONSchema(
                 type=JSONSchema.Type.ARRAY,
                 description=(
                     "The complete todo list. Each item must have: "
                     "'content' (imperative form like 'Fix bug'), "
                     "'status' (pending|in_progress|completed), "
                     "'active_form' (present continuous like 'Fixing bug'). "
                     "Optional: 'sub_items' (array of nested todo items)"
                 ),
                 items=JSONSchema(
                     type=JSONSchema.Type.OBJECT,
                     properties={
                         "content": JSONSchema(
                             type=JSONSchema.Type.STRING,
                             description="Imperative form of the task",
                             required=True,
                         ),
                         "status": JSONSchema(
                             type=JSONSchema.Type.STRING,
                             description="pending, in_progress, or completed",
                             enum=["pending", "in_progress", "completed"],
                             required=True,
                         ),
                         "active_form": JSONSchema(
                             type=JSONSchema.Type.STRING,
                             description="Present continuous form (e.g. 'Fixing')",
                             required=True,
                         ),
                         "sub_items": JSONSchema(
                             type=JSONSchema.Type.ARRAY,
                             description="Optional nested sub-tasks",
                             required=False,
                         ),
                     },
                 ),
                 required=True,
             ),
         },
     )
     def todo_write(self, todos: list[dict]) -> dict:
         """
         Replace the entire todo list with a new list.

         This is the primary command for managing todos. Use it to:
         - Create initial todos when starting a multi-step task
         - Mark tasks as in_progress when you start working on them
         - Mark tasks as completed when done
         - Add new tasks discovered during work
         - Remove tasks that are no longer relevant
         - Update sub-items created by todo_decompose

         The entire list is replaced atomically, ensuring consistency.
         Supports nested sub_items for hierarchical task tracking.
         """
         # Validate item count
         if len(todos) > self.config.max_items:
             return {
                 "status": "error",
                 "message": f"Too many items. Maximum is {self.config.max_items}.",
             }

         # Validate and convert items recursively
         validated_items = []
         for i, item in enumerate(todos):
             parsed, error = self._parse_todo_item(item, f"Item {i}")
             if error:
                 return {
                     "status": "error",
                     "message": error,
                 }
             if parsed:
                 validated_items.append(parsed)

         # Count in_progress items and warn if more than one
         in_progress_count = sum(1 for t in validated_items if t.status == "in_progress")
         warning = None
         if in_progress_count > 1:
             warning = (
                 f"Warning: {in_progress_count} tasks are in_progress. "
                 "Best practice is to have only ONE task in_progress at a time."
             )
             logger.warning(warning)

         # Replace the list
         self._todos = TodoList(items=validated_items)

         # Build response
         pending = sum(1 for t in validated_items if t.status == "pending")
         completed = sum(1 for t in validated_items if t.status == "completed")

         response = {
             "status": "success",
             "item_count": len(validated_items),
             "pending": pending,
             "in_progress": in_progress_count,
             "completed": completed,
         }

         if warning:
             response["warning"] = warning

         return response

     @command(names=["todo_read"])
     def todo_read(self) -> dict:
         """
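A quick usage sketch of the todo_write contract defined above, assuming `component` is an already-constructed TodoComponent instance; the response shape mirrors the implementation shown:

# Two in_progress items trigger the best-practice warning built above.
result = component.todo_write(
    [
        {"content": "Fix bug", "status": "in_progress", "active_form": "Fixing bug"},
        {"content": "Write tests", "status": "in_progress", "active_form": "Writing tests"},
    ]
)
assert result["item_count"] == 2
assert "2 tasks are in_progress" in result["warning"]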
@@ -562,3 +491,363 @@ class TodoComponent(
                 "status": "error",
                 "message": f"Decomposition failed: {e}",
             }
+
+    # -------------------------------------------------------------------------
+    # Incremental Operations - Token-efficient todo management
+    # -------------------------------------------------------------------------
+
+    @command(
+        names=["todo_add"],
+        parameters={
+            "content": JSONSchema(
+                type=JSONSchema.Type.STRING,
+                description="Imperative form of the task (e.g., 'Fix the bug')",
+                required=True,
+            ),
+            "active_form": JSONSchema(
+                type=JSONSchema.Type.STRING,
+                description="Present continuous form (e.g., 'Fixing the bug')",
+                required=True,
+            ),
+            "status": JSONSchema(
+                type=JSONSchema.Type.STRING,
+                description="Initial status: pending, in_progress, or completed",
+                enum=["pending", "in_progress", "completed"],
+                required=False,
+            ),
+            "index": JSONSchema(
+                type=JSONSchema.Type.INTEGER,
+                description="Position to insert at (0-based). Appends if omitted.",
+                required=False,
+            ),
+        },
+    )
+    def todo_add(
+        self,
+        content: str,
+        active_form: str,
+        status: TodoStatus = "pending",
+        index: Optional[int] = None,
+    ) -> dict:
+        """
+        Add a single todo item. Returns the created item with its ID.
+
+        This is the most token-efficient way to add a new task.
+        Use this instead of todo_write when adding one item to an existing list.
+        """
+        # Validate inputs
+        if not content or not content.strip():
+            return {"status": "error", "message": "'content' is required"}
+        if not active_form or not active_form.strip():
+            return {"status": "error", "message": "'active_form' is required"}
+
+        # Check max items
+        if len(self._todos.items) >= self.config.max_items:
+            return {
+                "status": "error",
+                "message": f"Cannot add: max items ({self.config.max_items}) reached",
+            }
+
+        # Create the new item
+        new_item = TodoItem(
+            content=content.strip(),
+            active_form=active_form.strip(),
+            status=status,
+        )
+
+        # Insert at specified index or append
+        if index is not None:
+            if index < 0:
+                index = 0
+            if index > len(self._todos.items):
+                index = len(self._todos.items)
+            self._todos.items.insert(index, new_item)
+        else:
+            self._todos.items.append(new_item)
+
+        return {
+            "status": "success",
+            "item": self._serialize_todo_item(new_item),
+            "total_items": len(self._todos.items),
+        }
+
+    @command(
+        names=["todo_set_status"],
+        parameters={
+            "id": JSONSchema(
+                type=JSONSchema.Type.STRING,
+                description="The unique ID of the todo to update",
+                required=True,
+            ),
+            "status": JSONSchema(
+                type=JSONSchema.Type.STRING,
+                description="New status: pending, in_progress, or completed",
+                enum=["pending", "in_progress", "completed"],
+                required=True,
+            ),
+        },
+    )
+    def todo_set_status(self, id: str, status: TodoStatus) -> dict:
+        """
+        Update just the status of a todo by ID.
+
+        This is the most common operation and the most token-efficient way
+        to mark a task as in_progress or completed.
+        """
+        item = self._find_by_id(id)
+        if not item:
+            return {"status": "error", "message": f"Todo with ID '{id}' not found"}
+
+        old_status = item.status
+        item.status = status
+
+        return {
+            "status": "success",
+            "item": self._serialize_todo_item(item),
+            "changed": {"status": {"from": old_status, "to": status}},
+        }
+
+    @command(
+        names=["todo_update"],
+        parameters={
+            "id": JSONSchema(
+                type=JSONSchema.Type.STRING,
+                description="The unique ID of the todo to update",
+                required=True,
+            ),
+            "content": JSONSchema(
+                type=JSONSchema.Type.STRING,
+                description="New imperative form (optional)",
+                required=False,
+            ),
+            "active_form": JSONSchema(
+                type=JSONSchema.Type.STRING,
+                description="New present continuous form (optional)",
+                required=False,
+            ),
+            "status": JSONSchema(
+                type=JSONSchema.Type.STRING,
+                description="New status (optional)",
+                enum=["pending", "in_progress", "completed"],
+                required=False,
+            ),
+        },
+    )
+    def todo_update(
+        self,
+        id: str,
+        content: Optional[str] = None,
+        active_form: Optional[str] = None,
+        status: Optional[TodoStatus] = None,
+    ) -> dict:
+        """
+        Partial update of a todo - only specified fields change.
+
+        Use this when you need to update multiple fields at once.
+        For just status changes, prefer todo_set_status.
+        """
+        item = self._find_by_id(id)
+        if not item:
+            return {"status": "error", "message": f"Todo with ID '{id}' not found"}
+
+        changes: dict[str, dict[str, str]] = {}
+
+        if content is not None:
+            if not content.strip():
+                return {"status": "error", "message": "'content' cannot be empty"}
+            changes["content"] = {"from": item.content, "to": content.strip()}
+            item.content = content.strip()
+
+        if active_form is not None:
+            if not active_form.strip():
+                return {"status": "error", "message": "'active_form' cannot be empty"}
+            changes["active_form"] = {
+                "from": item.active_form,
+                "to": active_form.strip(),
+            }
+            item.active_form = active_form.strip()
+
+        if status is not None:
+            changes["status"] = {"from": item.status, "to": status}
+            item.status = status
+
+        if not changes:
+            return {
+                "status": "success",
+                "item": self._serialize_todo_item(item),
+                "message": "No changes specified",
+            }
+
+        return {
+            "status": "success",
+            "item": self._serialize_todo_item(item),
+            "changed": changes,
+        }
+
+    @command(
+        names=["todo_delete"],
+        parameters={
+            "id": JSONSchema(
+                type=JSONSchema.Type.STRING,
+                description="The unique ID of the todo to delete",
+                required=True,
+            ),
+        },
+    )
+    def todo_delete(self, id: str) -> dict:
+        """
+        Explicitly delete a todo by ID.
+
+        Unlike todo_write where items are removed by omission (easy to accidentally
+        delete), this is an explicit delete operation.
+        """
+        index = self._find_index_by_id(id)
+        if index == -1:
+            return {"status": "error", "message": f"Todo with ID '{id}' not found"}
+
+        deleted_item = self._todos.items.pop(index)
+
+        return {
+            "status": "success",
+            "deleted": self._serialize_todo_item(deleted_item),
+            "remaining_items": len(self._todos.items),
+        }
+
+    @command(
+        names=["todo_bulk_add"],
+        parameters={
+            "items": JSONSchema(
+                type=JSONSchema.Type.ARRAY,
+                description="Array of todo items to add",
+                items=JSONSchema(
+                    type=JSONSchema.Type.OBJECT,
+                    properties={
+                        "content": JSONSchema(
+                            type=JSONSchema.Type.STRING,
+                            description="Imperative form of the task",
+                            required=True,
+                        ),
+                        "active_form": JSONSchema(
+                            type=JSONSchema.Type.STRING,
+                            description="Present continuous form",
+                            required=True,
+                        ),
+                        "status": JSONSchema(
+                            type=JSONSchema.Type.STRING,
+                            description="Initial status (default: pending)",
+                            enum=["pending", "in_progress", "completed"],
+                            required=False,
+                        ),
+                    },
+                ),
+                required=True,
+            ),
+        },
+    )
+    def todo_bulk_add(self, items: list[dict]) -> dict:
+        """
+        Add multiple todos at once. Use for initial planning.
+
+        This is efficient for creating the initial todo list at the start
+        of a multi-step task. For subsequent additions, use todo_add.
+        """
+        if not items:
+            return {"status": "error", "message": "No items provided"}
+
+        # Check max items
+        if len(self._todos.items) + len(items) > self.config.max_items:
+            return {
+                "status": "error",
+                "message": (
+                    f"Cannot add {len(items)} items: would exceed max "
+                    f"({self.config.max_items}). Current: {len(self._todos.items)}"
+                ),
+            }
+
+        added_items = []
+        for i, item in enumerate(items):
+            content = item.get("content", "").strip()
+            active_form = item.get("active_form", "").strip()
+            status = item.get("status", "pending")
+
+            if not content:
+                return {
+                    "status": "error",
+                    "message": f"Item {i}: 'content' is required",
+                }
+            if not active_form:
+                return {
+                    "status": "error",
+                    "message": f"Item {i}: 'active_form' is required",
+                }
+            if status not in ("pending", "in_progress", "completed"):
+                return {
+                    "status": "error",
+                    "message": f"Item {i}: invalid status '{status}'",
+                }
+
+            new_item = TodoItem(
+                content=content,
+                active_form=active_form,
+                status=status,
+            )
+            self._todos.items.append(new_item)
+            added_items.append(self._serialize_todo_item(new_item))
+
+        return {
+            "status": "success",
+            "added": added_items,
+            "added_count": len(added_items),
+            "total_items": len(self._todos.items),
+        }
+
+    @command(
+        names=["todo_reorder"],
+        parameters={
+            "ids": JSONSchema(
+                type=JSONSchema.Type.ARRAY,
+                description="List of todo IDs in the desired order",
+                items=JSONSchema(type=JSONSchema.Type.STRING),
+                required=True,
+            ),
+        },
+    )
+    def todo_reorder(self, ids: list[str]) -> dict:
+        """
+        Reorder todos by providing the ID list in desired order.
+
+        All current todo IDs must be included. This operation only
+        changes the order, not the items themselves.
+        """
+        current_ids = {item.id for item in self._todos.items}
+        provided_ids = set(ids)
+
+        # Check for duplicates first (before other checks)
+        if len(ids) != len(provided_ids):
+            return {"status": "error", "message": "Duplicate IDs in reorder list"}
+
+        # Validate that all provided IDs exist
+        unknown = provided_ids - current_ids
+        if unknown:
+            return {
+                "status": "error",
+                "message": f"Unknown todo IDs: {', '.join(unknown)}",
+            }
+
+        # Validate that all current IDs are provided
+        missing = current_ids - provided_ids
+        if missing:
+            return {
+                "status": "error",
+                "message": f"Missing todo IDs in reorder list: {', '.join(missing)}",
+            }
+
+        # Reorder
+        id_to_item = {item.id: item for item in self._todos.items}
+        self._todos.items = [id_to_item[id] for id in ids]
+
+        return {
+            "status": "success",
+            "order": ids,
+            "message": f"Reordered {len(ids)} items",
+        }
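Putting the incremental commands together, a typical ID-threaded session might look like this (again assuming `component` is a TodoComponent instance; every call returns a plain dict as implemented above):

# Plan with one bulk call, then manage items by ID.
added = component.todo_bulk_add(
    items=[
        {"content": "Fix bug", "active_form": "Fixing bug"},
        {"content": "Write tests", "active_form": "Writing tests"},
    ]
)["added"]

fix_id, tests_id = added[0]["id"], added[1]["id"]

component.todo_set_status(id=fix_id, status="in_progress")   # start work
component.todo_set_status(id=fix_id, status="completed")     # finish it
component.todo_update(id=tests_id, content="Write unit tests")
component.todo_reorder(ids=[tests_id, fix_id])               # must list ALL ids
component.todo_delete(id=fix_id)                             # explicit removal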
@@ -11,167 +11,6 @@ def todo_component():
     return TodoComponent()


-class TestTodoWrite:
-    """Tests for the todo_write command."""
-
-    def test_write_empty_list(self, todo_component):
-        """Writing an empty list should succeed."""
-        result = todo_component.todo_write([])
-        assert result["status"] == "success"
-        assert result["item_count"] == 0
-        assert result["pending"] == 0
-        assert result["in_progress"] == 0
-        assert result["completed"] == 0
-
-    def test_write_single_pending_todo(self, todo_component):
-        """Writing a single pending todo should succeed."""
-        result = todo_component.todo_write(
-            [
-                {
-                    "content": "Fix the bug",
-                    "status": "pending",
-                    "active_form": "Fixing the bug",
-                }
-            ]
-        )
-        assert result["status"] == "success"
-        assert result["item_count"] == 1
-        assert result["pending"] == 1
-        assert result["in_progress"] == 0
-
-    def test_write_multiple_todos(self, todo_component):
-        """Writing multiple todos with different statuses should succeed."""
-        result = todo_component.todo_write(
-            [
-                {
-                    "content": "Research patterns",
-                    "status": "completed",
-                    "active_form": "Researching patterns",
-                },
-                {
-                    "content": "Implement feature",
-                    "status": "in_progress",
-                    "active_form": "Implementing feature",
-                },
-                {
-                    "content": "Write tests",
-                    "status": "pending",
-                    "active_form": "Writing tests",
-                },
-            ]
-        )
-        assert result["status"] == "success"
-        assert result["item_count"] == 3
-        assert result["pending"] == 1
-        assert result["in_progress"] == 1
-        assert result["completed"] == 1
-
-    def test_write_replaces_entire_list(self, todo_component):
-        """Writing should replace the entire list, not append."""
-        # First write
-        todo_component.todo_write(
-            [
-                {
-                    "content": "Task 1",
-                    "status": "pending",
-                    "active_form": "Doing task 1",
-                }
-            ]
-        )
-
-        # Second write should replace
-        result = todo_component.todo_write(
-            [
-                {
-                    "content": "Task 2",
-                    "status": "pending",
-                    "active_form": "Doing task 2",
-                }
-            ]
-        )
-        assert result["item_count"] == 1
-
-        # Verify only Task 2 exists
-        read_result = todo_component.todo_read()
-        assert len(read_result["items"]) == 1
-        assert read_result["items"][0]["content"] == "Task 2"
-
-    def test_write_warns_on_multiple_in_progress(self, todo_component):
-        """Writing multiple in_progress items should include a warning."""
-        result = todo_component.todo_write(
-            [
-                {
-                    "content": "Task 1",
-                    "status": "in_progress",
-                    "active_form": "Doing task 1",
-                },
-                {
-                    "content": "Task 2",
-                    "status": "in_progress",
-                    "active_form": "Doing task 2",
-                },
-            ]
-        )
-        assert result["status"] == "success"
-        assert "warning" in result
-        assert "2 tasks are in_progress" in result["warning"]
-
-    def test_write_validates_required_content(self, todo_component):
-        """Writing without content should fail."""
-        result = todo_component.todo_write(
-            [
-                {
-                    "content": "",
-                    "status": "pending",
-                    "active_form": "Doing something",
-                }
-            ]
-        )
-        assert result["status"] == "error"
-        assert "content" in result["message"]
-
-    def test_write_validates_required_active_form(self, todo_component):
-        """Writing without active_form should fail."""
-        result = todo_component.todo_write(
-            [
-                {
-                    "content": "Fix bug",
-                    "status": "pending",
-                    "active_form": "",
-                }
-            ]
-        )
-        assert result["status"] == "error"
-        assert "active_form" in result["message"]
-
-    def test_write_validates_status(self, todo_component):
-        """Writing with invalid status should fail."""
-        result = todo_component.todo_write(
-            [
-                {
-                    "content": "Fix bug",
-                    "status": "invalid_status",
-                    "active_form": "Fixing bug",
-                }
-            ]
-        )
-        assert result["status"] == "error"
-        assert "status" in result["message"]
-
-    def test_write_enforces_max_items(self, todo_component):
-        """Writing more items than max_items should fail."""
-        component = TodoComponent(config=TodoConfiguration(max_items=2))
-        result = component.todo_write(
-            [
-                {"content": "Task 1", "status": "pending", "active_form": "Task 1"},
-                {"content": "Task 2", "status": "pending", "active_form": "Task 2"},
-                {"content": "Task 3", "status": "pending", "active_form": "Task 3"},
-            ]
-        )
-        assert result["status"] == "error"
-        assert "Too many items" in result["message"]
-
-
 class TestTodoRead:
     """Tests for the todo_read command."""

@@ -182,16 +21,10 @@ class TestTodoRead:
         assert result["items"] == []
         assert result["summary"]["pending"] == 0

-    def test_read_after_write(self, todo_component):
-        """Reading after writing should return the written items."""
-        todo_component.todo_write(
-            [
-                {
-                    "content": "Fix bug",
-                    "status": "pending",
-                    "active_form": "Fixing bug",
-                }
-            ]
+    def test_read_after_add(self, todo_component):
+        """Reading after adding should return the added items."""
+        todo_component.todo_add(
+            content="Fix bug", active_form="Fixing bug", status="pending"
         )

         result = todo_component.todo_read()
@@ -213,10 +46,10 @@ class TestTodoClear:
     def test_clear_populated_list(self, todo_component):
         """Clearing a populated list should remove all items."""
-        todo_component.todo_write(
-            [
-                {"content": "Task 1", "status": "pending", "active_form": "Task 1"},
-                {"content": "Task 2", "status": "pending", "active_form": "Task 2"},
+        todo_component.todo_bulk_add(
+            items=[
+                {"content": "Task 1", "active_form": "Task 1"},
+                {"content": "Task 2", "active_form": "Task 2"},
             ]
         )
@@ -241,14 +74,16 @@ class TestProtocols:
     def test_get_best_practices(self, todo_component):
         """DirectiveProvider.get_best_practices should yield practices."""
         practices = list(todo_component.get_best_practices())
-        assert len(practices) == 4
-        assert any("todo_write" in p for p in practices)
+        assert len(practices) == 6
+        assert any("todo_bulk_add" in p for p in practices)
+        assert any("todo_set_status" in p for p in practices)
         assert any("in_progress" in p for p in practices)

     def test_get_commands(self, todo_component):
         """CommandProvider.get_commands should yield commands."""
         commands = list(todo_component.get_commands())
         command_names = [c.names[0] for c in commands]
         assert "todo_write" in command_names
+        assert "todo_add" in command_names
         assert "todo_read" in command_names
         assert "todo_clear" in command_names
@@ -259,17 +94,17 @@ class TestProtocols:
     def test_get_messages_with_todos(self, todo_component):
         """MessageProvider should include todos in LLM context."""
-        todo_component.todo_write(
-            [
+        todo_component.todo_bulk_add(
+            items=[
                 {
                     "content": "Implement feature",
-                    "status": "in_progress",
                     "active_form": "Implementing feature",
+                    "status": "in_progress",
                 },
                 {
                     "content": "Write tests",
-                    "status": "pending",
                     "active_form": "Writing tests",
+                    "status": "pending",
                 },
             ]
         )
@@ -287,9 +122,7 @@ class TestProtocols:
     def test_get_messages_respects_show_in_prompt_config(self):
         """MessageProvider should respect show_in_prompt config."""
         component = TodoComponent(config=TodoConfiguration(show_in_prompt=False))
-        component.todo_write(
-            [{"content": "Task", "status": "pending", "active_form": "Task"}]
-        )
+        component.todo_add(content="Task", active_form="Task")

         messages = list(component.get_messages())
         assert len(messages) == 0
@@ -312,175 +145,13 @@ class TestConfiguration:
         assert component.config.show_in_prompt is False


-class TestSubItems:
-    """Tests for hierarchical sub-items support."""
-
-    def test_write_with_sub_items(self, todo_component):
-        """Writing todos with sub_items should succeed."""
-        result = todo_component.todo_write(
-            [
-                {
-                    "content": "Implement feature",
-                    "status": "in_progress",
-                    "active_form": "Implementing feature",
-                    "sub_items": [
-                        {
-                            "content": "Design API",
-                            "status": "completed",
-                            "active_form": "Designing API",
-                        },
-                        {
-                            "content": "Write code",
-                            "status": "in_progress",
-                            "active_form": "Writing code",
-                        },
-                        {
-                            "content": "Add tests",
-                            "status": "pending",
-                            "active_form": "Adding tests",
-                        },
-                    ],
-                }
-            ]
-        )
-        assert result["status"] == "success"
-        assert result["item_count"] == 1
-
-    def test_read_returns_sub_items(self, todo_component):
-        """Reading should return sub_items."""
-        todo_component.todo_write(
-            [
-                {
-                    "content": "Main task",
-                    "status": "in_progress",
-                    "active_form": "Working on main task",
-                    "sub_items": [
-                        {
-                            "content": "Sub task 1",
-                            "status": "completed",
-                            "active_form": "Doing sub task 1",
-                        },
-                        {
-                            "content": "Sub task 2",
-                            "status": "pending",
-                            "active_form": "Doing sub task 2",
-                        },
-                    ],
-                }
-            ]
-        )
-
-        result = todo_component.todo_read()
-        assert result["status"] == "success"
-        assert len(result["items"]) == 1
-        assert "sub_items" in result["items"][0]
-        assert len(result["items"][0]["sub_items"]) == 2
-        assert result["items"][0]["sub_items"][0]["content"] == "Sub task 1"
-        assert result["items"][0]["sub_items"][0]["status"] == "completed"
-
-    def test_nested_sub_items(self, todo_component):
-        """Writing deeply nested sub_items should succeed."""
-        result = todo_component.todo_write(
-            [
-                {
-                    "content": "Level 1",
-                    "status": "in_progress",
-                    "active_form": "Level 1",
-                    "sub_items": [
-                        {
-                            "content": "Level 2",
-                            "status": "pending",
-                            "active_form": "Level 2",
-                            "sub_items": [
-                                {
-                                    "content": "Level 3",
-                                    "status": "pending",
-                                    "active_form": "Level 3",
-                                }
-                            ],
-                        }
-                    ],
-                }
-            ]
-        )
-        assert result["status"] == "success"
-
-        # Verify nested structure
-        read_result = todo_component.todo_read()
-        level1 = read_result["items"][0]
-        level2 = level1["sub_items"][0]
-        level3 = level2["sub_items"][0]
-        assert level3["content"] == "Level 3"
-
-    def test_sub_items_validation_error(self, todo_component):
-        """Sub-items with invalid fields should fail validation."""
-        result = todo_component.todo_write(
-            [
-                {
-                    "content": "Main task",
-                    "status": "pending",
-                    "active_form": "Main task",
-                    "sub_items": [
-                        {
-                            "content": "",  # Invalid: empty content
-                            "status": "pending",
-                            "active_form": "Sub task",
-                        }
-                    ],
-                }
-            ]
-        )
-        assert result["status"] == "error"
-        assert "sub_items" in result["message"]
-
-    def test_messages_include_sub_items(self, todo_component):
-        """MessageProvider should format sub-items with indentation."""
-        todo_component.todo_write(
-            [
-                {
-                    "content": "Main task",
-                    "status": "in_progress",
-                    "active_form": "Working on main task",
-                    "sub_items": [
-                        {
-                            "content": "Sub completed",
-                            "status": "completed",
-                            "active_form": "Sub completed",
-                        },
-                        {
-                            "content": "Sub pending",
-                            "status": "pending",
-                            "active_form": "Sub pending",
-                        },
-                    ],
-                }
-            ]
-        )
-
-        messages = list(todo_component.get_messages())
-        assert len(messages) == 1
-        content = messages[0].content
-
-        # Check parent is shown
-        assert "Working on main task" in content
-        # Check sub-items are shown (with their status indicators)
-        assert "[x] Sub completed" in content
-        assert "[ ] Sub pending" in content
-
-
 class TestTodoDecompose:
     """Tests for the todo_decompose command."""

     def test_decompose_without_llm_provider(self, todo_component):
         """Decompose should fail gracefully without LLM provider."""
-        todo_component.todo_write(
-            [
-                {
-                    "content": "Complex task",
-                    "status": "pending",
-                    "active_form": "Complex task",
-                }
-            ]
+        todo_component.todo_add(
+            content="Complex task", active_form="Complex task", status="pending"
         )

         import asyncio
@@ -491,20 +162,6 @@ class TestTodoDecompose:
         assert result["status"] == "error"
         assert "LLM provider not configured" in result["message"]

-    def test_decompose_invalid_index(self, todo_component):
-        """Decompose with invalid index should fail."""
-        todo_component.todo_write(
-            [{"content": "Task", "status": "pending", "active_form": "Task"}]
-        )
-
-        import asyncio
-
-        result = asyncio.get_event_loop().run_until_complete(
-            todo_component.todo_decompose(item_index=5)
-        )
-        assert result["status"] == "error"
-        assert "Invalid item_index" in result["message"]
-
     def test_decompose_empty_list(self, todo_component):
         """Decompose on empty list should fail."""
         import asyncio
@@ -514,35 +171,384 @@ class TestTodoDecompose:
         )
         assert result["status"] == "error"

     def test_decompose_already_has_sub_items(self, todo_component):
         """Decompose should fail if item already has sub-items."""
         todo_component.todo_write(
             [
                 {
                     "content": "Task with subs",
                     "status": "pending",
                     "active_form": "Task with subs",
                     "sub_items": [
                         {
                             "content": "Existing sub",
                             "status": "pending",
                             "active_form": "Existing sub",
                         }
                     ],
                 }
             ]
         )

         import asyncio

         result = asyncio.get_event_loop().run_until_complete(
             todo_component.todo_decompose(item_index=0)
         )
         assert result["status"] == "error"
         assert "already has" in result["message"]

     def test_get_commands_includes_decompose(self, todo_component):
         """CommandProvider should include todo_decompose command."""
         commands = list(todo_component.get_commands())
         command_names = [c.names[0] for c in commands]
         assert "todo_decompose" in command_names


+class TestTodoAdd:
+    """Tests for the todo_add incremental command."""
+
+    def test_add_single_todo(self, todo_component):
+        """Adding a single todo should succeed and return the item with ID."""
+        result = todo_component.todo_add(
+            content="Fix the bug", active_form="Fixing the bug"
+        )
+        assert result["status"] == "success"
+        assert result["item"]["content"] == "Fix the bug"
+        assert result["item"]["active_form"] == "Fixing the bug"
+        assert result["item"]["status"] == "pending"
+        assert "id" in result["item"]
+        assert result["total_items"] == 1
+
+    def test_add_with_status(self, todo_component):
+        """Adding a todo with explicit status should work."""
+        result = todo_component.todo_add(
+            content="Task", active_form="Doing task", status="in_progress"
+        )
+        assert result["status"] == "success"
+        assert result["item"]["status"] == "in_progress"
+
+    def test_add_at_index(self, todo_component):
+        """Adding a todo at specific index should insert correctly."""
+        # Add two items first
+        todo_component.todo_add(content="First", active_form="First")
+        todo_component.todo_add(content="Third", active_form="Third")
+
+        # Insert at index 1
+        result = todo_component.todo_add(
+            content="Second", active_form="Second", index=1
+        )
+        assert result["status"] == "success"
+
+        # Verify order
+        read_result = todo_component.todo_read()
+        assert read_result["items"][0]["content"] == "First"
+        assert read_result["items"][1]["content"] == "Second"
+        assert read_result["items"][2]["content"] == "Third"
+
+    def test_add_validates_empty_content(self, todo_component):
+        """Adding with empty content should fail."""
+        result = todo_component.todo_add(content="", active_form="Doing something")
+        assert result["status"] == "error"
+        assert "content" in result["message"]
+
+    def test_add_validates_empty_active_form(self, todo_component):
+        """Adding with empty active_form should fail."""
+        result = todo_component.todo_add(content="Do something", active_form="")
+        assert result["status"] == "error"
+        assert "active_form" in result["message"]
+
+    def test_add_enforces_max_items(self):
+        """Adding should fail when max items reached."""
+        component = TodoComponent(config=TodoConfiguration(max_items=2))
+        component.todo_add(content="Task 1", active_form="Task 1")
+        component.todo_add(content="Task 2", active_form="Task 2")
+
+        result = component.todo_add(content="Task 3", active_form="Task 3")
+        assert result["status"] == "error"
+        assert "max items" in result["message"]
+
+
+class TestTodoSetStatus:
+    """Tests for the todo_set_status incremental command."""
+
+    def test_set_status_pending_to_in_progress(self, todo_component):
+        """Changing status from pending to in_progress should work."""
+        add_result = todo_component.todo_add(content="Task", active_form="Task")
+        item_id = add_result["item"]["id"]
+
+        result = todo_component.todo_set_status(id=item_id, status="in_progress")
+        assert result["status"] == "success"
+        assert result["item"]["status"] == "in_progress"
+        assert result["changed"]["status"]["from"] == "pending"
+        assert result["changed"]["status"]["to"] == "in_progress"
+
+    def test_set_status_to_completed(self, todo_component):
+        """Marking a task as completed should work."""
+        add_result = todo_component.todo_add(
+            content="Task", active_form="Task", status="in_progress"
+        )
+        item_id = add_result["item"]["id"]
+
+        result = todo_component.todo_set_status(id=item_id, status="completed")
+        assert result["status"] == "success"
+        assert result["item"]["status"] == "completed"
+
+    def test_set_status_invalid_id(self, todo_component):
+        """Setting status with invalid ID should fail."""
+        result = todo_component.todo_set_status(id="nonexistent", status="completed")
+        assert result["status"] == "error"
+        assert "not found" in result["message"]
+
+
+class TestTodoUpdate:
+    """Tests for the todo_update incremental command."""
+
+    def test_update_content_only(self, todo_component):
+        """Updating only content should preserve other fields."""
+        add_result = todo_component.todo_add(
+            content="Original", active_form="Original form", status="pending"
+        )
+        item_id = add_result["item"]["id"]
+
+        result = todo_component.todo_update(id=item_id, content="Updated")
+        assert result["status"] == "success"
+        assert result["item"]["content"] == "Updated"
+        assert result["item"]["active_form"] == "Original form"  # Unchanged
+        assert result["item"]["status"] == "pending"  # Unchanged
+        assert "content" in result["changed"]
+        assert "active_form" not in result["changed"]
+
+    def test_update_multiple_fields(self, todo_component):
+        """Updating multiple fields at once should work."""
+        add_result = todo_component.todo_add(content="Task", active_form="Task")
+        item_id = add_result["item"]["id"]
+
+        result = todo_component.todo_update(
+            id=item_id,
+            content="New content",
+            active_form="New form",
+            status="in_progress",
+        )
+        assert result["status"] == "success"
+        assert result["item"]["content"] == "New content"
+        assert result["item"]["active_form"] == "New form"
+        assert result["item"]["status"] == "in_progress"
+        assert len(result["changed"]) == 3
+
+    def test_update_no_changes(self, todo_component):
+        """Calling update with no changes should return success with message."""
+        add_result = todo_component.todo_add(content="Task", active_form="Task")
+        item_id = add_result["item"]["id"]
+
+        result = todo_component.todo_update(id=item_id)
+        assert result["status"] == "success"
+        assert "No changes" in result["message"]
+
+    def test_update_invalid_id(self, todo_component):
+        """Updating with invalid ID should fail."""
+        result = todo_component.todo_update(id="nonexistent", content="New")
+        assert result["status"] == "error"
+        assert "not found" in result["message"]
+
+    def test_update_validates_empty_content(self, todo_component):
+        """Updating content to empty should fail."""
+        add_result = todo_component.todo_add(content="Task", active_form="Task")
+        item_id = add_result["item"]["id"]
+
+        result = todo_component.todo_update(id=item_id, content="")
+        assert result["status"] == "error"
+        assert "content" in result["message"]
+
+
+class TestTodoDelete:
+    """Tests for the todo_delete incremental command."""
+
+    def test_delete_existing_todo(self, todo_component):
+        """Deleting an existing todo should succeed."""
+        add_result = todo_component.todo_add(content="Task", active_form="Task")
+        item_id = add_result["item"]["id"]
+
+        result = todo_component.todo_delete(id=item_id)
+        assert result["status"] == "success"
+        assert result["deleted"]["id"] == item_id
+        assert result["remaining_items"] == 0
+
+        # Verify it's gone
+        read_result = todo_component.todo_read()
+        assert len(read_result["items"]) == 0
+
+    def test_delete_from_middle(self, todo_component):
+        """Deleting from middle of list should preserve order."""
+        todo_component.todo_add(content="First", active_form="First")
+        add_result = todo_component.todo_add(content="Second", active_form="Second")
+        todo_component.todo_add(content="Third", active_form="Third")
+
+        result = todo_component.todo_delete(id=add_result["item"]["id"])
+        assert result["status"] == "success"
+        assert result["remaining_items"] == 2
+
+        read_result = todo_component.todo_read()
+        assert read_result["items"][0]["content"] == "First"
+        assert read_result["items"][1]["content"] == "Third"
+
+    def test_delete_invalid_id(self, todo_component):
+        """Deleting with invalid ID should fail."""
+        result = todo_component.todo_delete(id="nonexistent")
+        assert result["status"] == "error"
+        assert "not found" in result["message"]
+
+
+class TestTodoBulkAdd:
+    """Tests for the todo_bulk_add command."""
+
+    def test_bulk_add_multiple_items(self, todo_component):
+        """Bulk adding multiple items should succeed."""
+        result = todo_component.todo_bulk_add(
+            items=[
+                {"content": "Task 1", "active_form": "Task 1"},
+                {"content": "Task 2", "active_form": "Task 2", "status": "in_progress"},
+                {"content": "Task 3", "active_form": "Task 3"},
+            ]
+        )
+        assert result["status"] == "success"
+        assert result["added_count"] == 3
+        assert result["total_items"] == 3
+        assert len(result["added"]) == 3
+
+        # Each item should have an ID
+        for item in result["added"]:
+            assert "id" in item
+
+    def test_bulk_add_empty_list(self, todo_component):
+        """Bulk adding empty list should fail."""
+        result = todo_component.todo_bulk_add(items=[])
+        assert result["status"] == "error"
+        assert "No items" in result["message"]
+
+    def test_bulk_add_validates_content(self, todo_component):
+        """Bulk add should validate each item's content."""
+        result = todo_component.todo_bulk_add(
+            items=[
+                {"content": "Valid", "active_form": "Valid"},
+                {"content": "", "active_form": "Invalid"},
+            ]
+        )
+        assert result["status"] == "error"
+        assert "Item 1" in result["message"]
+        assert "content" in result["message"]
+
+    def test_bulk_add_validates_active_form(self, todo_component):
+        """Bulk add should validate each item's active_form."""
+        result = todo_component.todo_bulk_add(
+            items=[
+                {"content": "Valid", "active_form": ""},
+            ]
+        )
+        assert result["status"] == "error"
+        assert "active_form" in result["message"]
+
+    def test_bulk_add_validates_status(self, todo_component):
+        """Bulk add should validate each item's status."""
+        result = todo_component.todo_bulk_add(
+            items=[
+                {"content": "Task", "active_form": "Task", "status": "invalid"},
+            ]
+        )
+        assert result["status"] == "error"
+        assert "status" in result["message"]
+
+    def test_bulk_add_enforces_max_items(self):
+        """Bulk add should respect max items limit."""
+        component = TodoComponent(config=TodoConfiguration(max_items=2))
+
+        result = component.todo_bulk_add(
+            items=[
+                {"content": "Task 1", "active_form": "Task 1"},
+                {"content": "Task 2", "active_form": "Task 2"},
+                {"content": "Task 3", "active_form": "Task 3"},
+            ]
+        )
+        assert result["status"] == "error"
+        assert "exceed max" in result["message"]
+
+
+class TestTodoReorder:
+    """Tests for the todo_reorder command."""
+
+    def test_reorder_todos(self, todo_component):
+        """Reordering todos should change their order."""
+        r1 = todo_component.todo_add(content="First", active_form="First")
+        r2 = todo_component.todo_add(content="Second", active_form="Second")
+        r3 = todo_component.todo_add(content="Third", active_form="Third")
+
+        # Reverse the order
+        result = todo_component.todo_reorder(
+            ids=[r3["item"]["id"], r2["item"]["id"], r1["item"]["id"]]
+        )
+        assert result["status"] == "success"
+
+        read_result = todo_component.todo_read()
+        assert read_result["items"][0]["content"] == "Third"
+        assert read_result["items"][1]["content"] == "Second"
+        assert read_result["items"][2]["content"] == "First"
+
+    def test_reorder_missing_ids(self, todo_component):
+        """Reorder with missing IDs should fail."""
+        r1 = todo_component.todo_add(content="First", active_form="First")
+        todo_component.todo_add(content="Second", active_form="Second")
+
+        result = todo_component.todo_reorder(ids=[r1["item"]["id"]])
+        assert result["status"] == "error"
+        assert "Missing todo IDs" in result["message"]
+
+    def test_reorder_unknown_ids(self, todo_component):
+        """Reorder with unknown IDs should fail."""
+        r1 = todo_component.todo_add(content="First", active_form="First")
+
+        result = todo_component.todo_reorder(ids=[r1["item"]["id"], "unknown_id"])
+        assert result["status"] == "error"
+        assert "Unknown todo IDs" in result["message"]
+
+    def test_reorder_duplicate_ids(self, todo_component):
+        """Reorder with duplicate IDs should fail."""
+        r1 = todo_component.todo_add(content="First", active_form="First")
+        todo_component.todo_add(content="Second", active_form="Second")
+
+        result = todo_component.todo_reorder(ids=[r1["item"]["id"], r1["item"]["id"]])
|
||||
assert result["status"] == "error"
|
||||
assert "Duplicate" in result["message"]
|
||||
|
||||
|
||||
class TestTodoIdIntegration:
|
||||
"""Tests for ID functionality across operations."""
|
||||
|
||||
def test_ids_are_unique(self, todo_component):
|
||||
"""Each added todo should have a unique ID."""
|
||||
ids = set()
|
||||
for i in range(10):
|
||||
result = todo_component.todo_add(
|
||||
content=f"Task {i}", active_form=f"Task {i}"
|
||||
)
|
||||
ids.add(result["item"]["id"])
|
||||
|
||||
assert len(ids) == 10
|
||||
|
||||
def test_id_preserved_on_status_change(self, todo_component):
|
||||
"""ID should be preserved when status changes."""
|
||||
add_result = todo_component.todo_add(content="Task", active_form="Task")
|
||||
original_id = add_result["item"]["id"]
|
||||
|
||||
todo_component.todo_set_status(id=original_id, status="in_progress")
|
||||
todo_component.todo_set_status(id=original_id, status="completed")
|
||||
|
||||
read_result = todo_component.todo_read()
|
||||
assert read_result["items"][0]["id"] == original_id
|
||||
|
||||
def test_todo_read_includes_ids(self, todo_component):
|
||||
"""todo_read should return items with IDs."""
|
||||
todo_component.todo_add(content="Task", active_form="Task")
|
||||
|
||||
result = todo_component.todo_read()
|
||||
assert "id" in result["items"][0]
|
||||
|
||||
def test_bulk_add_generates_ids(self, todo_component):
|
||||
"""todo_bulk_add should generate IDs for items."""
|
||||
result = todo_component.todo_bulk_add(
|
||||
items=[{"content": "Task", "active_form": "Task"}]
|
||||
)
|
||||
assert "id" in result["added"][0]
|
||||
|
||||
read_result = todo_component.todo_read()
|
||||
assert "id" in read_result["items"][0]
|
||||
|
||||
|
||||
class TestIncrementalOperationsCommands:
|
||||
"""Tests for incremental operations being registered as commands."""
|
||||
|
||||
def test_all_incremental_commands_registered(self, todo_component):
|
||||
"""All incremental commands should be registered."""
|
||||
commands = list(todo_component.get_commands())
|
||||
command_names = [c.names[0] for c in commands]
|
||||
|
||||
assert "todo_add" in command_names
|
||||
assert "todo_set_status" in command_names
|
||||
assert "todo_update" in command_names
|
||||
assert "todo_delete" in command_names
|
||||
assert "todo_bulk_add" in command_names
|
||||
assert "todo_reorder" in command_names
|
||||
# todo_write is removed - incremental operations only
|
||||
assert "todo_write" not in command_names
|
||||
|
||||
@@ -1,13 +1,15 @@
from typing import Optional
from typing import TYPE_CHECKING, Optional

from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from autogpt.app.config import AppConfig
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.file_storage.base import FileStorage
from forge.llm.providers import MultiProvider
from forge.permissions import CommandPermissionManager

from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from autogpt.app.config import AppConfig
if TYPE_CHECKING:
    from forge.agent.execution_context import ExecutionContext


def create_agent(
@@ -19,6 +21,7 @@ def create_agent(
    ai_profile: Optional[AIProfile] = None,
    directives: Optional[AIDirectives] = None,
    permission_manager: Optional[CommandPermissionManager] = None,
    execution_context: Optional["ExecutionContext"] = None,
) -> Agent:
    if not task:
        raise ValueError("No task specified for new agent")
@@ -34,6 +37,7 @@ def create_agent(
        file_storage=file_storage,
        llm_provider=llm_provider,
        permission_manager=permission_manager,
        execution_context=execution_context,
    )

    return agent
@@ -45,6 +49,7 @@ def configure_agent_with_state(
    file_storage: FileStorage,
    llm_provider: MultiProvider,
    permission_manager: Optional[CommandPermissionManager] = None,
    execution_context: Optional["ExecutionContext"] = None,
) -> Agent:
    return _configure_agent(
        state=state,
@@ -52,6 +57,7 @@ def configure_agent_with_state(
        file_storage=file_storage,
        llm_provider=llm_provider,
        permission_manager=permission_manager,
        execution_context=execution_context,
    )


@@ -65,6 +71,7 @@ def _configure_agent(
    directives: Optional[AIDirectives] = None,
    state: Optional[AgentSettings] = None,
    permission_manager: Optional[CommandPermissionManager] = None,
    execution_context: Optional["ExecutionContext"] = None,
) -> Agent:
    if state:
        agent_state = state
@@ -88,6 +95,7 @@ def _configure_agent(
        file_storage=file_storage,
        app_config=app_config,
        permission_manager=permission_manager,
        execution_context=execution_context,
    )
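For illustration, a caller that already holds a context can thread it through these configurators; this is a sketch using only the parameters visible in this diff, with the remaining arguments and their values assumed:

# Sketch: spawning a worker agent that joins an existing parent's
# resource budget instead of building a fresh root context.
# All concrete values below are illustrative.
worker = create_agent(
    task="Collect background material for the report",
    app_config=app_config,
    file_storage=file_storage,
    llm_provider=llm_provider,
    execution_context=parent.execution_context,  # reuse, don't create a root
)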
@@ -0,0 +1,147 @@
"""Default implementation of AgentFactory for sub-agent spawning.

This factory creates Agent instances for use as sub-agents within
a prompt strategy. It follows the same pattern as the direct_benchmark
runner for agent creation.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Optional

from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from forge.agent.execution_context import AgentFactory, ExecutionContext
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile

if TYPE_CHECKING:
    from autogpt.app.config import AppConfig


class DefaultAgentFactory(AgentFactory):
    """Default implementation of AgentFactory.

    Creates Agent instances for sub-agent spawning. Reuses the pattern
    from direct_benchmark/runner.py for agent creation.

    The factory is stateless - all configuration comes from the AppConfig
    and ExecutionContext.
    """

    def __init__(self, app_config: "AppConfig"):
        """Initialize the factory.

        Args:
            app_config: The application configuration to use for
                creating agents. This provides LLM settings, disabled
                commands, etc.
        """
        self.app_config = app_config

    def create_agent(
        self,
        agent_id: str,
        task: str,
        context: ExecutionContext,
        ai_profile: Optional[AIProfile] = None,
        directives: Optional[AIDirectives] = None,
        strategy: Optional[str] = None,
    ) -> Agent:
        """Create a new agent instance for sub-agent execution.

        Args:
            agent_id: Unique identifier for the agent.
            task: The task the agent should accomplish.
            context: Execution context with shared resources.
            ai_profile: Optional AI profile override. If not provided,
                a default profile is created.
            directives: Optional directives override. If not provided,
                default directives are used.
            strategy: Optional strategy name override (e.g., "one_shot").
                If not provided, uses the app_config default.

        Returns:
            A new Agent instance configured for the task.
        """
        # Create default profile if not provided
        if ai_profile is None:
            ai_profile = AIProfile(
                ai_name=f"SubAgent-{agent_id[:8]}",
                ai_role="A specialized sub-agent working on a specific task.",
            )

        # Create default directives if not provided
        if directives is None:
            directives = AIDirectives(
                constraints=[
                    "Focus only on the assigned task.",
                    "Do not ask for user input - work autonomously.",
                    "Complete the task efficiently and call finish when done.",
                ],
                resources=[
                    "The same tools as your parent agent.",
                ],
                best_practices=[
                    "Think step by step.",
                    "Be concise in your outputs.",
                ],
            )

        # Create agent settings
        agent_state = self._create_agent_state(
            agent_id=agent_id,
            task=task,
            ai_profile=ai_profile,
            directives=directives,
        )

        # Copy app config and optionally override strategy
        config = self.app_config.model_copy(deep=True)
        if strategy:
            config.prompt_strategy = strategy

        # Sub-agents should always be non-interactive
        config.noninteractive_mode = True
        config.continuous_mode = True

        # Create the agent with the provided execution context
        return Agent(
            settings=agent_state,
            llm_provider=context.llm_provider,
            file_storage=context.file_storage,
            app_config=config,
            execution_context=context,
        )

    def _create_agent_state(
        self,
        agent_id: str,
        task: str,
        ai_profile: AIProfile,
        directives: AIDirectives,
    ) -> AgentSettings:
        """Create the agent settings/state object.

        Args:
            agent_id: Unique identifier for the agent.
            task: The task the agent should accomplish.
            ai_profile: The AI profile for this agent.
            directives: The directives for this agent.

        Returns:
            AgentSettings configured for the sub-agent.
        """
        return AgentSettings(
            agent_id=agent_id,
            name=Agent.default_settings.name,
            description=Agent.default_settings.description,
            task=task,
            ai_profile=ai_profile,
            directives=directives,
            config=AgentConfiguration(
                fast_llm=self.app_config.fast_llm,
                smart_llm=self.app_config.smart_llm,
                allow_fs_access=not self.app_config.restrict_to_workspace,
            ),
            history=Agent.default_settings.history.model_copy(deep=True),
        )
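A usage sketch for the factory (illustrative; it assumes an existing root ExecutionContext named `root_context`, and the task string is a placeholder):

factory = DefaultAgentFactory(app_config)
child_context = root_context.create_child_context("sub-0001")
sub_agent = factory.create_agent(
    agent_id="sub-0001",
    task="Summarize the findings gathered so far",
    context=child_context,
    strategy="one_shot",  # pin a cheap strategy for the sub-agent
)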
@@ -1,6 +1,7 @@
import json
import logging

from autogpt.app.config import AppConfig
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.llm.prompting import ChatPrompt, LanguageModelClassification, PromptStrategy
@@ -13,8 +14,6 @@ from forge.llm.providers.schema import (
from forge.models.config import SystemConfiguration, UserConfigurable
from forge.models.json_schema import JSONSchema

from autogpt.app.config import AppConfig

logger = logging.getLogger(__name__)
@@ -5,7 +5,10 @@ import logging
from typing import TYPE_CHECKING, Any, ClassVar, Optional

import sentry_sdk
from pydantic import Field

from forge.agent.base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
from forge.agent.execution_context import ExecutionContext
from forge.agent.protocols import (
    AfterExecute,
    AfterParse,
@@ -63,8 +66,9 @@ from forge.utils.exceptions import (
    CommandExecutionError,
    UnknownCommandError,
)
from pydantic import Field

from .prompt_strategies.lats import LATSActionProposal
from .prompt_strategies.multi_agent_debate import DebateActionProposal
from .prompt_strategies.one_shot import (
    OneShotAgentActionProposal,
    OneShotAgentPromptStrategy,
@@ -90,6 +94,8 @@ AnyActionProposal = (
    | ReWOOActionProposal
    | ReflexionActionProposal
    | ToTActionProposal
    | LATSActionProposal
    | DebateActionProposal
)

if TYPE_CHECKING:
@@ -128,11 +134,27 @@ class Agent(BaseAgent[AnyActionProposal], Configurable[AgentSettings]):
        file_storage: FileStorage,
        app_config: AppConfig,
        permission_manager: Optional[CommandPermissionManager] = None,
        execution_context: Optional[ExecutionContext] = None,
    ):
        super().__init__(settings, permission_manager=permission_manager)

        self.llm_provider = llm_provider
        self.app_config = app_config

        # Create or use provided execution context
        if execution_context:
            self.execution_context = execution_context
        else:
            # Root agent - create new context
            self.execution_context = self._create_root_execution_context(
                llm_provider, file_storage, app_config
            )

        # Create prompt strategy and inject execution context
        self.prompt_strategy = self._create_prompt_strategy(app_config)
        if hasattr(self.prompt_strategy, "set_execution_context"):
            self.prompt_strategy.set_execution_context(self.execution_context)

        self.commands: list[Command] = []

        # Components
@@ -181,7 +203,40 @@ class Agent(BaseAgent[AnyActionProposal], Configurable[AgentSettings]):
        )

        self.event_history = settings.history
        self.app_config = app_config

    def _create_root_execution_context(
        self,
        llm_provider: MultiProvider,
        file_storage: FileStorage,
        app_config: AppConfig,
    ) -> ExecutionContext:
        """Create execution context for a root (top-level) agent.

        Root agents create their own execution context with:
        - Full access to shared resources
        - Default resource budget
        - An agent factory for spawning sub-agents

        Args:
            llm_provider: The LLM provider instance.
            file_storage: The file storage instance.
            app_config: The application configuration.

        Returns:
            A new ExecutionContext for this root agent.
        """
        from autogpt.agent_factory.default_factory import DefaultAgentFactory

        factory = DefaultAgentFactory(app_config)

        return ExecutionContext(
            llm_provider=llm_provider,
            file_storage=file_storage,
            agent_factory=factory,
            parent_agent_id=None,  # Root agent has no parent
            depth=0,
            _app_config=app_config,
        )
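The root/child split can be seen directly from the constructor; a sketch with assumed variable names (a context passed in is reused, otherwise a fresh depth-0 context is built):

root = Agent(settings=settings, llm_provider=provider,
             file_storage=storage, app_config=config)
assert root.execution_context.depth == 0  # root built its own context

child_ctx = root.execution_context.create_child_context("sub-0001")
child = Agent(settings=sub_settings, llm_provider=provider,
              file_storage=storage, app_config=config,
              execution_context=child_ctx)  # inherits budget and storage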
    async def propose_action(self) -> AnyActionProposal:
        """Proposes the next action to execute, based on the task and current state.
@@ -458,6 +513,22 @@ class Agent(BaseAgent[AnyActionProposal], Configurable[AgentSettings]):
            tot_config.use_prefill = use_prefill
            return TreeOfThoughtsPromptStrategy(tot_config, logger)

        elif strategy_name == "lats":
            from .prompt_strategies.lats import LATSPromptStrategy

            lats_config = LATSPromptStrategy.default_configuration.model_copy(deep=True)
            lats_config.use_prefill = use_prefill
            return LATSPromptStrategy(lats_config, logger)

        elif strategy_name == "multi_agent_debate":
            from .prompt_strategies.multi_agent_debate import MultiAgentDebateStrategy

            debate_config = MultiAgentDebateStrategy.default_configuration.model_copy(
                deep=True
            )
            debate_config.use_prefill = use_prefill
            return MultiAgentDebateStrategy(debate_config, logger)

        else:  # Default to one_shot
            os_config = OneShotAgentPromptStrategy.default_configuration.model_copy(
                deep=True
@@ -3,9 +3,8 @@ from __future__ import annotations
import uuid
from pathlib import Path

from forge.file_storage.base import FileStorage

from autogpt.agents.agent import AgentSettings
from forge.file_storage.base import FileStorage


class AgentManager:
@@ -6,6 +6,7 @@ implementations including ReWOO, Plan-and-Execute, Reflexion, and Tree of Though

from __future__ import annotations

import asyncio
import enum
import platform
from abc import ABC, abstractmethod
@@ -14,6 +15,14 @@ from logging import Logger
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar

import distro
from pydantic import BaseModel, ConfigDict, Field

from forge.agent.execution_context import (
    ExecutionContext,
    SubAgentHandle,
    SubAgentStatus,
    generate_sub_agent_id,
)
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.llm.prompting import ChatPrompt, LanguageModelClassification
@@ -26,7 +35,6 @@ from forge.llm.providers.schema import (
from forge.models.action import ActionProposal
from forge.models.config import SystemConfiguration, UserConfigurable
from forge.models.utils import ModelWithSummary
from pydantic import BaseModel, ConfigDict, Field

if TYPE_CHECKING:
    pass
@@ -40,6 +48,8 @@ class PromptStrategyType(str, enum.Enum):
    PLAN_EXECUTE = "plan_execute"
    REFLEXION = "reflexion"
    TREE_OF_THOUGHTS = "tree_of_thoughts"
    LATS = "lats"  # Language Agent Tree Search (sub-agent based)
    MULTI_AGENT_DEBATE = "multi_agent_debate"  # Multi-agent debate (sub-agent based)


class PlannedStep(BaseModel):
@@ -248,12 +258,27 @@ class BasePromptStrategyConfiguration(SystemConfiguration):
    body_template: str = UserConfigurable(default=DEFAULT_BODY_TEMPLATE)
    use_prefill: bool = True

    # Sub-agent configuration
    enable_sub_agents: bool = UserConfigurable(default=True)
    """Enable sub-agent spawning for this strategy."""

    max_sub_agents: int = UserConfigurable(default=5)
    """Maximum number of sub-agents that can be spawned."""

    sub_agent_timeout_seconds: int = UserConfigurable(default=300)
    """Timeout for sub-agent execution in seconds."""

    sub_agent_max_cycles: int = UserConfigurable(default=25)
    """Maximum execution cycles per sub-agent."""
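These four fields give every multi-step strategy a uniform budget knob. For example, a caller can tighten them on a copied configuration before constructing a strategy (a sketch; `MyStrategy` is a hypothetical subclass and the values are illustrative):

config = MyStrategy.default_configuration.model_copy(deep=True)
config.enable_sub_agents = True
config.max_sub_agents = 3             # below the default of 5
config.sub_agent_timeout_seconds = 60
config.sub_agent_max_cycles = 5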
class BaseMultiStepPromptStrategy(ABC):
    """Base class for multi-step prompt strategies.

    Provides common utilities for strategies that involve multiple phases
    like planning, execution, synthesis, or reflection.

    Also provides sub-agent spawning capabilities when enabled via config.
    """

    def __init__(
@@ -263,6 +288,300 @@ class BaseMultiStepPromptStrategy(ABC):
    ):
        self.config = configuration
        self.logger = logger
        self._execution_context: Optional[ExecutionContext] = None

    # ===== Sub-Agent Support Methods =====

    def set_execution_context(self, context: ExecutionContext) -> None:
        """Inject execution context. Called by Agent after creation.

        This provides the strategy with access to shared resources needed
        for sub-agent spawning (LLM provider, file storage, agent factory).

        Args:
            context: The execution context from the parent agent.
        """
        self._execution_context = context
        self.logger.debug(
            f"ExecutionContext set (depth={context.depth}, "
            f"sub_agents_enabled={self.config.enable_sub_agents})"
        )

    def can_spawn_sub_agent(self) -> bool:
        """Check if sub-agent spawning is available and allowed.

        Returns:
            True if sub-agents can be spawned, False otherwise.
        """
        if not self.config.enable_sub_agents:
            return False
        if self._execution_context is None:
            return False
        return self._execution_context.can_spawn_sub_agent()

    async def spawn_sub_agent(
        self,
        task: str,
        ai_profile: Optional[AIProfile] = None,
        directives: Optional[AIDirectives] = None,
        strategy: Optional[str] = None,
    ) -> SubAgentHandle:
        """Spawn a sub-agent to handle a subtask.

        The sub-agent runs with its own execution context (reduced budget,
        restricted file storage) and can be run synchronously or in the
        background.

        Args:
            task: The task for the sub-agent to accomplish.
            ai_profile: Optional AI profile override.
            directives: Optional directives override.
            strategy: Optional strategy name override (e.g., "one_shot").

        Returns:
            A SubAgentHandle for tracking and interacting with the sub-agent.

        Raises:
            RuntimeError: If sub-agent spawning is not available.
        """
        if not self.can_spawn_sub_agent():
            raise RuntimeError(
                "Cannot spawn sub-agent: "
                + ("not enabled" if not self.config.enable_sub_agents else "no context")
            )

        assert self._execution_context is not None

        # Generate unique ID for sub-agent
        parent_id = self._execution_context.parent_agent_id
        agent_id = generate_sub_agent_id(parent_id)

        # Create handle
        handle = SubAgentHandle(
            agent_id=agent_id,
            task=task,
            status=SubAgentStatus.PENDING,
        )

        # Create child context with restricted resources
        child_context = self._execution_context.create_child_context(agent_id)

        # Create the sub-agent via factory
        factory = self._execution_context.agent_factory
        if factory is None:
            raise RuntimeError("No agent factory available")

        try:
            agent = factory.create_agent(
                agent_id=agent_id,
                task=task,
                context=child_context,
                ai_profile=ai_profile,
                directives=directives,
                strategy=strategy,
            )
            handle._agent = agent
            handle.status = SubAgentStatus.PENDING

            # Register with parent context
            self._execution_context.register_sub_agent(handle)

            self.logger.info(f"Spawned sub-agent {agent_id} for task: {task[:100]}...")

        except Exception as e:
            handle.status = SubAgentStatus.FAILED
            handle.error = str(e)
            self.logger.error(f"Failed to spawn sub-agent: {e}")

        return handle

    async def run_sub_agent(
        self,
        handle: SubAgentHandle,
        max_cycles: Optional[int] = None,
    ) -> Any:
        """Run a sub-agent until completion.

        Executes the sub-agent's action loop until it finishes or hits
        the cycle limit.

        Args:
            handle: The sub-agent handle from spawn_sub_agent().
            max_cycles: Maximum cycles to run (default from config).

        Returns:
            The result from the sub-agent (typically the finish command output).

        Raises:
            RuntimeError: If the sub-agent is not in a runnable state.
        """
        if handle._agent is None:
            raise RuntimeError(f"Sub-agent {handle.agent_id} has no agent instance")

        if handle.status not in (SubAgentStatus.PENDING, SubAgentStatus.RUNNING):
            raise RuntimeError(
                f"Sub-agent {handle.agent_id} is not runnable "
                f"(status={handle.status})"
            )

        max_cycles = max_cycles or self.config.sub_agent_max_cycles
        timeout = self.config.sub_agent_timeout_seconds

        handle.status = SubAgentStatus.RUNNING
        agent = handle._agent

        try:
            result = await asyncio.wait_for(
                self._run_agent_loop(agent, max_cycles, handle),
                timeout=timeout,
            )
            handle.result = result
            handle.status = SubAgentStatus.COMPLETED
            return result

        except asyncio.TimeoutError:
            handle.status = SubAgentStatus.FAILED
            handle.error = f"Timed out after {timeout}s"
            self.logger.warning(f"Sub-agent {handle.agent_id} timed out")
            return None

        except asyncio.CancelledError:
            handle.status = SubAgentStatus.CANCELLED
            self.logger.info(f"Sub-agent {handle.agent_id} was cancelled")
            raise

        except Exception as e:
            handle.status = SubAgentStatus.FAILED
            handle.error = str(e)
            self.logger.error(f"Sub-agent {handle.agent_id} failed: {e}")
            return None

    async def _run_agent_loop(
        self,
        agent: Any,
        max_cycles: int,
        handle: SubAgentHandle,
    ) -> Any:
        """Run the agent's propose/execute loop.

        Args:
            agent: The agent instance.
            max_cycles: Maximum cycles to run.
            handle: The sub-agent handle for status tracking.

        Returns:
            The result from the finish command, or None if max cycles reached.
        """
        for cycle in range(max_cycles):
            # Check for cancellation
            if self._execution_context and self._execution_context.cancelled:
                handle.status = SubAgentStatus.CANCELLED
                return None

            # Propose next action
            proposal = await agent.propose_action()

            # Check for finish command
            if proposal.use_tool.name == "finish":
                # Extract result from finish arguments
                result = proposal.use_tool.arguments.get("reason", "")
                handle.summary = result[:200] if result else "Task completed"
                return result

            # Execute the action
            result = await agent.execute(proposal)

            # Log progress
            self.logger.debug(
                f"Sub-agent {handle.agent_id} cycle {cycle + 1}: "
                f"{proposal.use_tool.name}"
            )

        # Hit max cycles
        handle.summary = f"Reached max cycles ({max_cycles})"
        return None

    async def spawn_and_run(
        self,
        task: str,
        ai_profile: Optional[AIProfile] = None,
        directives: Optional[AIDirectives] = None,
        strategy: Optional[str] = None,
        max_cycles: Optional[int] = None,
    ) -> Any:
        """Convenience method: spawn and immediately run a sub-agent.

        This is the most common pattern for sub-agent usage.

        Args:
            task: The task for the sub-agent.
            ai_profile: Optional AI profile override.
            directives: Optional directives override.
            strategy: Optional strategy name override.
            max_cycles: Maximum cycles to run.

        Returns:
            The result from the sub-agent.
        """
        handle = await self.spawn_sub_agent(
            task=task,
            ai_profile=ai_profile,
            directives=directives,
            strategy=strategy,
        )

        if handle.status == SubAgentStatus.FAILED:
            self.logger.error(f"Failed to spawn sub-agent: {handle.error}")
            return None

        return await self.run_sub_agent(handle, max_cycles=max_cycles)
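As the docstring notes, this is the usual entry point; from inside a strategy method the call looks like this (the task text and cycle budget are illustrative):

result = await self.spawn_and_run(
    task="Run the test suite and report any failures",
    strategy="one_shot",  # cheap strategy for a narrow subtask
    max_cycles=10,
)
if result is None:
    self.logger.warning("Sub-agent did not finish within its budget")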
    async def run_parallel(
        self,
        tasks: list[str],
        strategy: Optional[str] = None,
        max_cycles: Optional[int] = None,
    ) -> list[Any]:
        """Run multiple sub-agents in parallel.

        Useful for patterns like multi-agent debate or parallel exploration.

        Args:
            tasks: List of tasks for sub-agents.
            strategy: Optional strategy name for all sub-agents.
            max_cycles: Maximum cycles per sub-agent.

        Returns:
            List of results from all sub-agents (in the same order as tasks).
        """
        # Spawn all sub-agents
        handles = []
        for task in tasks:
            handle = await self.spawn_sub_agent(task=task, strategy=strategy)
            handles.append(handle)

        # Run all in parallel
        async def run_one(h: SubAgentHandle) -> Any:
            if h.status == SubAgentStatus.FAILED:
                return None
            return await self.run_sub_agent(h, max_cycles=max_cycles)

        results = await asyncio.gather(*[run_one(h) for h in handles])
        return list(results)
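A sketch of the fan-out this enables, as in the debate strategy's proposal phase (the task strings are illustrative; failed spawns come back as None):

results = await self.run_parallel(
    tasks=[
        "Propose a fix focused on correctness",
        "Propose a fix focused on performance",
        "Propose a fix focused on simplicity",
    ],
    strategy="one_shot",
    max_cycles=8,
)
usable = [r for r in results if r is not None]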
    def get_sub_agent_results(self) -> dict[str, Any]:
        """Get results from all completed sub-agents.

        Returns:
            Dictionary mapping agent_id to result.
        """
        if self._execution_context is None:
            return {}

        return {
            agent_id: handle.result
            for agent_id, handle in self._execution_context.sub_agents.items()
            if handle.status == SubAgentStatus.COMPLETED
        }

    @property
    @abstractmethod
@@ -377,6 +696,14 @@ def get_strategy_class(
        from .tree_of_thoughts import TreeOfThoughtsPromptStrategy

        return TreeOfThoughtsPromptStrategy
    elif strategy_type == PromptStrategyType.LATS:
        from .lats import LATSPromptStrategy

        return LATSPromptStrategy
    elif strategy_type == PromptStrategyType.MULTI_AGENT_DEBATE:
        from .multi_agent_debate import MultiAgentDebateStrategy

        return MultiAgentDebateStrategy

    if strategy_type not in strategy_map:
        raise ValueError(f"Unknown strategy type: {strategy_type}")
@@ -0,0 +1,488 @@
"""LATS (Language Agent Tree Search) prompt strategy.

This strategy implements the LATS algorithm from the paper:
"Language Agent Tree Search Unifies Reasoning Acting and Planning in Language Models"

LATS uses sub-agents to explore different reasoning paths with Monte Carlo Tree Search,
combining the benefits of tree search with LLM-based evaluation.

Key features:
- Sub-agents explore different action paths in parallel
- Monte Carlo Tree Search for intelligent exploration
- Value function learned from sub-agent outcomes
- Reflection on failed paths to improve future exploration
"""

from __future__ import annotations

import json
import re
from enum import Enum
from logging import Logger
from typing import Any, Optional

from pydantic import BaseModel, Field

from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.json.parsing import extract_dict_from_json
from forge.llm.prompting import ChatPrompt, LanguageModelClassification
from forge.llm.providers.schema import (
    AssistantChatMessage,
    ChatMessage,
    CompletionModelFunction,
)
from forge.models.action import ActionProposal
from forge.models.config import UserConfigurable
from forge.models.json_schema import JSONSchema
from forge.models.utils import ModelWithSummary
from forge.utils.exceptions import InvalidAgentResponseError

from .base import BaseMultiStepPromptStrategy, BasePromptStrategyConfiguration


class LATSPhase(str, Enum):
    """Phases of the LATS algorithm."""

    SELECTION = "selection"  # Select node to expand using UCT
    EXPANSION = "expansion"  # Generate candidate actions via sub-agents
    EVALUATION = "evaluation"  # Evaluate candidates
    BACKPROPAGATION = "backpropagation"  # Update value estimates
    EXECUTION = "execution"  # Execute best action


class LATSNode(BaseModel):
    """A node in the LATS search tree."""

    state_description: str = Field(description="Description of the state at this node")
    action_taken: Optional[str] = Field(
        default=None, description="Action that led to this state"
    )
    value: float = Field(default=0.0, description="Estimated value (Q-value)")
    visits: int = Field(default=0, description="Number of times this node was visited")
    children: list["LATSNode"] = Field(default_factory=list)
    parent: Optional["LATSNode"] = Field(default=None, exclude=True)
    depth: int = Field(default=0)
    is_terminal: bool = Field(default=False)
    reward: float = Field(default=0.0, description="Reward received at this node")
    reflection: str = Field(default="", description="Reflection on failures")

    model_config = {"arbitrary_types_allowed": True}

    def uct_score(self, exploration_weight: float = 1.41) -> float:
        """Calculate UCT (Upper Confidence Bound for Trees) score."""
        if self.visits == 0:
            return float("inf")  # Encourage exploration of unvisited nodes

        if self.parent is None or self.parent.visits == 0:
            return self.value

        import math

        exploitation = self.value / self.visits
        exploration = exploration_weight * math.sqrt(
            math.log(self.parent.visits) / self.visits
        )
        return exploitation + exploration
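A worked example of the score (numbers are illustrative): a child with value 2.0 over 4 visits under a parent with 20 visits yields 2.0/4 + 1.41 * sqrt(ln(20)/4) ≈ 0.5 + 1.22 ≈ 1.72, while an unvisited sibling scores inf and is therefore tried first:

parent = LATSNode(state_description="root", visits=20)
child = LATSNode(state_description="after step", parent=parent,
                 value=2.0, visits=4)
print(round(child.uct_score(), 2))  # ~1.72
print(LATSNode(state_description="untried", parent=parent).uct_score())  # inf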
class LATSThoughts(ModelWithSummary):
    """Thoughts for LATS strategy."""

    observations: str = Field(description="Current observations from the state")
    reasoning: str = Field(description="Reasoning about which path to take")
    candidate_actions: list[str] = Field(
        default_factory=list, description="Candidate actions being considered"
    )
    selected_action: str = Field(description="The action selected for execution")
    confidence: float = Field(
        default=0.5, description="Confidence in the selected action (0-1)"
    )

    def summary(self) -> str:
        return self.selected_action


class LATSActionProposal(ActionProposal):
    """Action proposal for LATS strategy."""

    thoughts: LATSThoughts  # type: ignore


class LATSPromptConfiguration(BasePromptStrategyConfiguration):
    """Configuration for LATS strategy."""

    # MCTS parameters
    num_candidates: int = UserConfigurable(default=3)
    """Number of candidate actions to generate per expansion."""

    max_depth: int = UserConfigurable(default=5)
    """Maximum depth of the search tree."""

    exploration_weight: float = UserConfigurable(default=1.41)
    """UCT exploration weight (sqrt(2) is theoretically optimal)."""

    num_simulations: int = UserConfigurable(default=3)
    """Number of MCTS simulations per decision."""

    # Sub-agent configuration (inherited, but with LATS-specific defaults)
    enable_sub_agents: bool = UserConfigurable(default=True)
    max_sub_agents: int = UserConfigurable(default=10)
    sub_agent_timeout_seconds: int = UserConfigurable(default=120)
    sub_agent_max_cycles: int = UserConfigurable(default=10)

    DEFAULT_EXPANSION_INSTRUCTION: str = (
        "You are exploring possible actions for a task. "
        "Generate {num_candidates} distinct candidate actions "
        "that could make progress.\n\n"
        "Current state: {state}\n"
        "Task: {task}\n\n"
        "For each candidate, provide:\n"
        "1. The action name and arguments\n"
        "2. Expected outcome\n"
        "3. Potential risks\n\n"
        "Format as JSON array of objects with "
        "'action', 'expected_outcome', 'risks' keys."
    )

    DEFAULT_EVALUATION_INSTRUCTION: str = (
        "Evaluate the following action outcome.\n\n"
        "Action: {action}\n"
        "Result: {result}\n"
        "Task goal: {task}\n\n"
        "Provide a score from 0.0 to 1.0 indicating progress toward the goal.\n"
        "Also provide a brief reflection on what worked or didn't work.\n\n"
        "Format: {{'score': 0.X, 'reflection': '...'}}"
    )

    expansion_instruction: str = UserConfigurable(default=DEFAULT_EXPANSION_INSTRUCTION)
    evaluation_instruction: str = UserConfigurable(
        default=DEFAULT_EVALUATION_INSTRUCTION
    )


class LATSPromptStrategy(BaseMultiStepPromptStrategy):
    """LATS (Language Agent Tree Search) prompt strategy.

    Uses sub-agents to explore different action paths with MCTS,
    combining tree search with LLM-based value estimation.
    """

    default_configuration = LATSPromptConfiguration()

    def __init__(
        self,
        configuration: LATSPromptConfiguration,
        logger: Logger,
    ):
        super().__init__(configuration, logger)
        self.config: LATSPromptConfiguration = configuration
        self.response_schema = JSONSchema.from_dict(
            LATSActionProposal.model_json_schema()
        )

        # LATS state
        self.root: Optional[LATSNode] = None
        self.current_node: Optional[LATSNode] = None
        self.phase = LATSPhase.SELECTION
        self.simulation_count = 0
        self.candidate_actions: list[dict[str, Any]] = []

    @property
    def llm_classification(self) -> LanguageModelClassification:
        return LanguageModelClassification.SMART_MODEL

    def build_prompt(
        self,
        *,
        messages: list[ChatMessage],
        task: str,
        ai_profile: AIProfile,
        ai_directives: AIDirectives,
        commands: list[CompletionModelFunction],
        include_os_info: bool,
        **extras,
    ) -> ChatPrompt:
        """Build prompt based on current LATS phase."""
        # Initialize root node if needed
        if self.root is None:
            self.root = LATSNode(
                state_description=f"Initial state for task: {task}",
                depth=0,
            )
            self.current_node = self.root

        system_prompt = self._build_system_prompt(
            ai_profile, ai_directives, commands, include_os_info
        )

        # Add LATS-specific context
        lats_context = self._build_lats_context(task)

        return ChatPrompt(
            messages=[
                ChatMessage.system(system_prompt),
                ChatMessage.user(f'Task: """{task}"""'),
                *messages,
                ChatMessage.system(lats_context),
                ChatMessage.user(self._get_phase_instruction()),
            ],
            prefill_response='{\n "thoughts":',
            functions=commands,
        )

    def _build_system_prompt(
        self,
        ai_profile: AIProfile,
        ai_directives: AIDirectives,
        commands: list[CompletionModelFunction],
        include_os_info: bool,
    ) -> str:
        """Build the system prompt."""
        intro = self.generate_intro_prompt(ai_profile)
        body = self.build_body(ai_directives, commands)

        lats_intro = (
            "\n\n## LATS Strategy\n"
            "You are using Language Agent Tree Search (LATS) to explore actions.\n"
            "This involves:\n"
            "1. Generating candidate actions\n"
            "2. Evaluating their potential\n"
            "3. Selecting the most promising path\n"
            "4. Learning from outcomes to improve future decisions\n"
        )

        response_format = self._build_response_format()

        parts = intro + [body, lats_intro, response_format]
        if include_os_info:
            parts.extend(self.generate_os_info())

        return "\n\n".join(parts)

    def _build_lats_context(self, task: str) -> str:
        """Build context about current LATS state."""
        if self.current_node is None:
            return ""

        context_parts = [
            "## Current Search State",
            f"Phase: {self.phase.value}",
            f"Tree depth: {self.current_node.depth}",
            f"Simulations completed: "
            f"{self.simulation_count}/{self.config.num_simulations}",
        ]

        if self.current_node.reflection:
            context_parts.append(f"Previous reflection: {self.current_node.reflection}")

        if self.candidate_actions:
            context_parts.append(
                f"Candidate actions under consideration: {len(self.candidate_actions)}"
            )

        return "\n".join(context_parts)

    def _get_phase_instruction(self) -> str:
        """Get instruction for current phase."""
        if self.phase == LATSPhase.SELECTION:
            return (
                "Select the next action to execute. Consider the search tree "
                "state and choose the most promising action based on UCT scores "
                "and your reasoning."
            )
        elif self.phase == LATSPhase.EXPANSION:
            return (
                f"Generate {self.config.num_candidates} candidate actions. "
                "Each should be a distinct approach to making progress on the task."
            )
        elif self.phase == LATSPhase.EVALUATION:
            return (
                "Evaluate the outcome of the last action. "
                "Score progress from 0.0 to 1.0 and reflect on what worked."
            )
        else:
            return "Execute the selected action."

    def _build_response_format(self) -> str:
        """Build response format instruction."""
        response_schema = self.response_schema.model_copy(deep=True)
        if response_schema.properties and "use_tool" in response_schema.properties:
            del response_schema.properties["use_tool"]

        return (
            "## Response Format\n"
            "Respond with a JSON object containing your thoughts and invoke a tool.\n"
            f"{response_schema.to_typescript_object_interface('LATSResponse')}"
        )

    async def expand_with_sub_agents(self, task: str, state: str) -> list[dict]:
        """Use sub-agents to generate candidate actions."""
        if not self.can_spawn_sub_agent():
            self.logger.warning("Cannot spawn sub-agents for LATS expansion")
            return []

        expansion_tasks = []
        for i in range(self.config.num_candidates):
            sub_task = (
                f"You are candidate explorer #{i + 1}. "
                f"Task: {task}\n"
                f"Current state: {state}\n"
                f"Propose ONE specific action to make progress. "
                f"Focus on a unique approach different from other explorers."
            )
            expansion_tasks.append(sub_task)

        # Run sub-agents in parallel
        try:
            results = await self.run_parallel(
                expansion_tasks,
                strategy="one_shot",
                max_cycles=self.config.sub_agent_max_cycles,
            )

            candidates = []
            for i, result in enumerate(results):
                if result:
                    candidates.append(
                        {
                            "index": i,
                            "suggestion": str(result)[:500],
                            "source": f"sub-agent-{i}",
                        }
                    )

            self.candidate_actions = candidates
            return candidates

        except Exception as e:
            self.logger.error(f"LATS expansion failed: {e}")
            return []

    async def evaluate_with_sub_agent(
        self, action: str, result: str, task: str
    ) -> tuple[float, str]:
        """Use a sub-agent to evaluate an action outcome."""
        if not self.can_spawn_sub_agent():
            return 0.5, "Unable to evaluate (no sub-agent available)"

        eval_task = self.config.evaluation_instruction.format(
            action=action,
            result=result,
            task=task,
        )

        try:
            eval_result = await self.spawn_and_run(
                eval_task,
                strategy="one_shot",
                max_cycles=5,
            )

            if eval_result:
                # Try to parse score and reflection
                try:
                    parsed = json.loads(str(eval_result))
                    score = float(parsed.get("score", 0.5))
                    reflection = parsed.get("reflection", "")
                    return score, reflection
                except (json.JSONDecodeError, ValueError):
                    # Extract score from text
                    score_match = re.search(r"(\d+\.?\d*)", str(eval_result))
                    score = float(score_match.group(1)) if score_match else 0.5
                    return min(score, 1.0), str(eval_result)[:200]

            return 0.5, "Evaluation completed without result"

        except Exception as e:
            self.logger.error(f"LATS evaluation failed: {e}")
            return 0.5, f"Evaluation error: {e}"

    def select_node(self) -> LATSNode:
        """Select node to expand using UCT."""
        if self.root is None:
            raise RuntimeError("LATS tree not initialized")

        node = self.root
        while node.children and not node.is_terminal:
            # Select child with highest UCT score
            node = max(
                node.children,
                key=lambda n: n.uct_score(self.config.exploration_weight),
            )

        return node

    def backpropagate(self, node: LATSNode, reward: float) -> None:
        """Backpropagate reward through the tree."""
        current = node
        while current is not None:
            current.visits += 1
            current.value += reward
            current = current.parent
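Together with select_node(), this closes the MCTS loop; a minimal cycle looks like the following sketch (`strategy` stands for a live LATSPromptStrategy whose tree has been initialized, and the reward is illustrative):

leaf = strategy.select_node()             # descend by UCT to a frontier node
strategy.backpropagate(leaf, reward=1.0)  # every ancestor gains a visit and the reward
assert strategy.root is not None and strategy.root.visits >= 1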
    def parse_response_content(
        self,
        response: AssistantChatMessage,
    ) -> LATSActionProposal:
        """Parse the LLM response into a LATS action proposal."""
        if not response.content:
            raise InvalidAgentResponseError("Assistant response has no text content")

        self.logger.debug(f"LLM response content:\n{response.content[:500]}")

        assistant_reply_dict = extract_dict_from_json(response.content)

        if not response.tool_calls:
            raise InvalidAgentResponseError("Assistant did not use a tool")

        assistant_reply_dict["use_tool"] = response.tool_calls[0].function

        parsed_response = LATSActionProposal.model_validate(assistant_reply_dict)
        parsed_response.raw_message = response.model_copy()

        # Update LATS state based on response
        self._update_state_from_response(parsed_response)

        return parsed_response

    def _update_state_from_response(self, response: LATSActionProposal) -> None:
        """Update LATS state after receiving a response."""
        if self.current_node is None:
            return

        # Create child node for the action taken
        child = LATSNode(
            state_description=f"After: {response.use_tool.name}",
            action_taken=response.use_tool.name,
            parent=self.current_node,
            depth=self.current_node.depth + 1,
        )
        self.current_node.children.append(child)
        self.current_node = child

        # Advance phase
        self.simulation_count += 1
        if self.simulation_count >= self.config.num_simulations:
            self.phase = LATSPhase.EXECUTION
        else:
            self.phase = LATSPhase.SELECTION

    def record_execution_result(
        self, variable_name: str, result: str, error: Optional[str] = None
    ) -> None:
        """Record execution result for backpropagation."""
        if self.current_node is None:
            return

        # Simple reward based on success/failure
        if error:
            reward = 0.0
            self.current_node.reflection = f"Action failed: {error}"
        else:
            reward = 0.5  # Base reward for successful execution
            if "success" in result.lower() or "completed" in result.lower():
                reward = 1.0

        self.current_node.reward = reward
        self.backpropagate(self.current_node, reward)
@@ -0,0 +1,572 @@
"""Multi-Agent Debate prompt strategy.

This strategy implements a multi-agent debate approach where multiple sub-agents
propose solutions, debate their merits, and reach a consensus.

Based on research from:
- "Improving Factuality and Reasoning in Language Models through Multiagent Debate"
- Google ADK Multi-Agent Patterns

Key features:
- Multiple sub-agents generate independent proposals
- Agents critique each other's proposals
- Consensus is reached through voting or synthesis
- Improves reasoning through diverse perspectives
"""

from __future__ import annotations

import json
import re
from enum import Enum
from logging import Logger
from typing import Any, Optional

from pydantic import BaseModel, Field

from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.json.parsing import extract_dict_from_json
from forge.llm.prompting import ChatPrompt, LanguageModelClassification
from forge.llm.providers.schema import (
    AssistantChatMessage,
    ChatMessage,
    CompletionModelFunction,
)
from forge.models.action import ActionProposal
from forge.models.config import UserConfigurable
from forge.models.json_schema import JSONSchema
from forge.models.utils import ModelWithSummary
from forge.utils.exceptions import InvalidAgentResponseError

from .base import BaseMultiStepPromptStrategy, BasePromptStrategyConfiguration


class DebatePhase(str, Enum):
    """Phases of the multi-agent debate."""

    PROPOSAL = "proposal"  # Agents generate initial proposals
    CRITIQUE = "critique"  # Agents critique each other's proposals
    REVISION = "revision"  # Agents revise based on critiques
    CONSENSUS = "consensus"  # Synthesize final decision
    EXECUTION = "execution"  # Execute the consensus action


class AgentProposal(BaseModel):
    """A proposal from a debate agent."""

    agent_id: str = Field(description="ID of the proposing agent")
    action_name: str = Field(description="Proposed action name")
    action_args: dict[str, Any] = Field(
        default_factory=dict, description="Proposed action arguments"
    )
    reasoning: str = Field(description="Reasoning behind the proposal")
    confidence: float = Field(default=0.5, description="Confidence in proposal (0-1)")


class AgentCritique(BaseModel):
    """A critique of another agent's proposal."""

    critic_id: str = Field(description="ID of the critiquing agent")
    target_agent_id: str = Field(description="ID of the agent being critiqued")
    strengths: list[str] = Field(default_factory=list)
    weaknesses: list[str] = Field(default_factory=list)
    suggestions: list[str] = Field(default_factory=list)
    score: float = Field(default=0.5, description="Score for the proposal (0-1)")


class DebateState(BaseModel):
    """Current state of the debate."""

    proposals: list[AgentProposal] = Field(default_factory=list)
    critiques: list[AgentCritique] = Field(default_factory=list)
    revision_count: int = Field(default=0)
    consensus_reached: bool = Field(default=False)
    winning_proposal: Optional[AgentProposal] = None


class DebateThoughts(ModelWithSummary):
    """Thoughts for debate strategy."""

    observations: str = Field(description="Observations about the debate state")
    debate_summary: str = Field(description="Summary of the debate so far")
    reasoning: str = Field(description="Reasoning for the selected action")
    confidence: float = Field(default=0.5, description="Confidence in decision (0-1)")

    def summary(self) -> str:
        return self.debate_summary


class DebateActionProposal(ActionProposal):
    """Action proposal from debate strategy."""

    thoughts: DebateThoughts  # type: ignore


class MultiAgentDebateConfiguration(BasePromptStrategyConfiguration):
    """Configuration for multi-agent debate strategy."""

    num_debaters: int = UserConfigurable(default=3)
    """Number of debate agents to spawn."""

    num_rounds: int = UserConfigurable(default=2)
    """Number of debate rounds (proposal -> critique -> revision)."""

    consensus_threshold: float = UserConfigurable(default=0.7)
    """Agreement threshold for consensus (0-1)."""

    use_voting: bool = UserConfigurable(default=True)
    """Use voting for consensus vs. synthesis."""

    # Sub-agent configuration
    enable_sub_agents: bool = UserConfigurable(default=True)
    max_sub_agents: int = UserConfigurable(default=10)
    sub_agent_timeout_seconds: int = UserConfigurable(default=180)
    sub_agent_max_cycles: int = UserConfigurable(default=8)

    DEFAULT_PROPOSAL_INSTRUCTION: str = (
        "You are Debater #{agent_num} in a multi-agent debate.\n\n"
        "Task: {task}\n"
        "Available commands: {commands}\n\n"
        "Propose ONE specific action to accomplish this task.\n"
        "Explain your reasoning and why this approach is best.\n\n"
        "Format your response as:\n"
        "ACTION: <command_name>\n"
        "ARGUMENTS: <json arguments>\n"
        "REASONING: <your reasoning>\n"
        "CONFIDENCE: <0.0-1.0>"
    )

    DEFAULT_CRITIQUE_INSTRUCTION: str = (
        "You are a critic evaluating another agent's proposal.\n\n"
        "Task: {task}\n"
        "Proposal being critiqued:\n"
        "- Action: {action}\n"
        "- Arguments: {arguments}\n"
        "- Reasoning: {reasoning}\n\n"
        "Provide a balanced critique:\n"
        "STRENGTHS: <what's good about this proposal>\n"
        "WEAKNESSES: <potential issues or risks>\n"
        "SUGGESTIONS: <how to improve>\n"
        "SCORE: <0.0-1.0>"
    )

    DEFAULT_CONSENSUS_INSTRUCTION: str = (
        "The debate has concluded. Here are the final proposals and their scores:\n\n"
        "{proposals_summary}\n\n"
        "Based on the debate, select the best action to take.\n"
        "You may combine ideas from multiple proposals if beneficial."
    )

    proposal_instruction: str = UserConfigurable(default=DEFAULT_PROPOSAL_INSTRUCTION)
    critique_instruction: str = UserConfigurable(default=DEFAULT_CRITIQUE_INSTRUCTION)
    consensus_instruction: str = UserConfigurable(default=DEFAULT_CONSENSUS_INSTRUCTION)
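The voting path that consensus_threshold and use_voting describe can be sketched as follows (an illustrative helper, not part of this diff; it only uses the models defined above):

def pick_winner(
    proposals: list[AgentProposal],
    critiques: list[AgentCritique],
    threshold: float = 0.7,
) -> Optional[AgentProposal]:
    # Average each proposal's critique scores and accept the leader
    # only once it clears the consensus threshold.
    def avg(agent_id: str) -> float:
        scores = [c.score for c in critiques if c.target_agent_id == agent_id]
        return sum(scores) / len(scores) if scores else 0.0

    best = max(proposals, key=lambda p: avg(p.agent_id), default=None)
    return best if best and avg(best.agent_id) >= threshold else None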
|
||||
|
||||
class MultiAgentDebateStrategy(BaseMultiStepPromptStrategy):
    """Multi-Agent Debate prompt strategy.

    Spawns multiple sub-agents that propose, critique, and debate
    to reach consensus on the best action.
    """

    default_configuration = MultiAgentDebateConfiguration()

    def __init__(
        self,
        configuration: MultiAgentDebateConfiguration,
        logger: Logger,
    ):
        super().__init__(configuration, logger)
        self.config: MultiAgentDebateConfiguration = configuration
        self.response_schema = JSONSchema.from_dict(
            DebateActionProposal.model_json_schema()
        )

        # Debate state
        self.debate_state = DebateState()
        self.phase = DebatePhase.PROPOSAL
        self.current_round = 0
        self._commands_str = ""

    @property
    def llm_classification(self) -> LanguageModelClassification:
        return LanguageModelClassification.SMART_MODEL

    def build_prompt(
        self,
        *,
        messages: list[ChatMessage],
        task: str,
        ai_profile: AIProfile,
        ai_directives: AIDirectives,
        commands: list[CompletionModelFunction],
        include_os_info: bool,
        **extras,
    ) -> ChatPrompt:
        """Build prompt based on current debate phase."""
        # Store commands for sub-agents
        self._commands_str = ", ".join(cmd.name for cmd in commands)

        system_prompt = self._build_system_prompt(
            ai_profile, ai_directives, commands, include_os_info
        )

        debate_context = self._build_debate_context()

        return ChatPrompt(
            messages=[
                ChatMessage.system(system_prompt),
                ChatMessage.user(f'Task: """{task}"""'),
                *messages,
                ChatMessage.system(debate_context),
                ChatMessage.user(self._get_phase_instruction(task)),
            ],
            prefill_response='{\n "thoughts":',
            functions=commands,
        )

    def _build_system_prompt(
        self,
        ai_profile: AIProfile,
        ai_directives: AIDirectives,
        commands: list[CompletionModelFunction],
        include_os_info: bool,
    ) -> str:
        """Build the system prompt."""
        intro = self.generate_intro_prompt(ai_profile)
        body = self.build_body(ai_directives, commands)

        debate_intro = (
            "\n\n## Multi-Agent Debate Strategy\n"
            "You are the coordinator of a multi-agent debate.\n"
            "Multiple agents will propose and critique solutions.\n"
            "Your role is to:\n"
            "1. Orchestrate the debate process\n"
            "2. Synthesize insights from all agents\n"
            "3. Select the best action based on consensus\n"
        )

        response_format = self._build_response_format()

        parts = intro + [body, debate_intro, response_format]
        if include_os_info:
            parts.extend(self.generate_os_info())

        return "\n\n".join(parts)

    def _build_debate_context(self) -> str:
        """Build context about current debate state."""
        context_parts = [
            "## Debate State",
            f"Phase: {self.phase.value}",
            f"Round: {self.current_round + 1}/{self.config.num_rounds}",
            f"Proposals collected: {len(self.debate_state.proposals)}",
            f"Critiques collected: {len(self.debate_state.critiques)}",
        ]

        if self.debate_state.proposals:
            context_parts.append("\n### Current Proposals:")
            for p in self.debate_state.proposals:
                avg_score = self._get_proposal_score(p.agent_id)
                context_parts.append(
                    f"- {p.agent_id}: {p.action_name} (score: {avg_score:.2f})"
                )

        if self.debate_state.consensus_reached:
            context_parts.append("\n✓ Consensus reached!")
            if self.debate_state.winning_proposal:
                wp = self.debate_state.winning_proposal
                context_parts.append(f"Winner: {wp.action_name}")

        return "\n".join(context_parts)

    def _get_proposal_score(self, agent_id: str) -> float:
        """Get average critique score for a proposal."""
        scores = [
            c.score
            for c in self.debate_state.critiques
            if c.target_agent_id == agent_id
        ]
        return sum(scores) / len(scores) if scores else 0.5

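    # Worked example (illustrative, not part of the commit): two critiques
    # scoring a proposal 0.6 and 0.8 give _get_proposal_score() = 0.7; with no
    # critiques recorded yet, the method falls back to the neutral prior 0.5.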
    def _get_phase_instruction(self, task: str) -> str:
        """Get instruction for current phase."""
        if self.phase == DebatePhase.PROPOSAL:
            if not self.debate_state.proposals:
                return (
                    "The debate is starting. Sub-agents will now generate proposals. "
                    "Invoke 'finish' with reason 'Starting debate' to begin, "
                    "or take a direct action if you're confident."
                )
            return "Review the proposals and proceed to critique phase."

        elif self.phase == DebatePhase.CRITIQUE:
            return "Sub-agents are critiquing proposals. Proceed to synthesis."

        elif self.phase == DebatePhase.CONSENSUS:
            return self.config.consensus_instruction.format(
                proposals_summary=self._format_proposals_summary()
            )

        else:  # EXECUTION
            return "Execute the consensus action."

    def _format_proposals_summary(self) -> str:
        """Format proposals for consensus instruction."""
        lines = []
        for p in self.debate_state.proposals:
            score = self._get_proposal_score(p.agent_id)
            lines.append(
                f"Proposal from {p.agent_id}:\n"
                f"  Action: {p.action_name}({p.action_args})\n"
                f"  Reasoning: {p.reasoning}\n"
                f"  Score: {score:.2f}"
            )
        return "\n\n".join(lines)

    def _build_response_format(self) -> str:
        """Build response format instruction."""
        response_schema = self.response_schema.model_copy(deep=True)
        if response_schema.properties and "use_tool" in response_schema.properties:
            del response_schema.properties["use_tool"]

        return (
            "## Response Format\n"
            "Respond with a JSON object and invoke a tool.\n"
            f"{response_schema.to_typescript_object_interface('DebateResponse')}"
        )

    async def run_proposal_phase(self, task: str) -> list[AgentProposal]:
        """Run the proposal phase with sub-agents."""
        if not self.can_spawn_sub_agent():
            self.logger.warning("Cannot spawn sub-agents for debate")
            return []

        proposal_tasks = []
        for i in range(self.config.num_debaters):
            sub_task = self.config.proposal_instruction.format(
                agent_num=i + 1,
                task=task,
                commands=self._commands_str,
            )
            proposal_tasks.append(sub_task)

        try:
            results = await self.run_parallel(
                proposal_tasks,
                strategy="one_shot",
                max_cycles=self.config.sub_agent_max_cycles,
            )

            proposals = []
            for i, result in enumerate(results):
                if result:
                    proposal = self._parse_proposal(f"debater-{i + 1}", str(result))
                    if proposal:
                        proposals.append(proposal)

            self.debate_state.proposals = proposals
            self.phase = DebatePhase.CRITIQUE
            return proposals

        except Exception as e:
            self.logger.error(f"Proposal phase failed: {e}")
            return []

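    # Illustrative fan-out (not part of the commit): with the default
    # num_debaters=3, run_proposal_phase() issues three parallel "one_shot"
    # sub-agent runs via run_parallel(), each seeded with the proposal
    # instruction, and keeps whichever replies parse into AgentProposal objects.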
    def _parse_proposal(self, agent_id: str, result: str) -> Optional[AgentProposal]:
        """Parse a proposal from sub-agent output."""
        try:
            # Try to extract structured data
            action_match = re.search(r"ACTION:\s*(\w+)", result, re.IGNORECASE)
            args_match = re.search(r"ARGUMENTS:\s*(\{.*?\})", result, re.DOTALL)
            reasoning_match = re.search(
                r"REASONING:\s*(.+?)(?=CONFIDENCE:|$)",
                result,
                re.DOTALL | re.IGNORECASE,
            )
            confidence_match = re.search(r"CONFIDENCE:\s*([\d.]+)", result)

            if action_match:
                action_name = action_match.group(1)
                action_args = {}
                if args_match:
                    try:
                        action_args = json.loads(args_match.group(1))
                    except json.JSONDecodeError:
                        pass

                reasoning = (
                    reasoning_match.group(1).strip()
                    if reasoning_match
                    else result[:200]
                )
                confidence = (
                    float(confidence_match.group(1)) if confidence_match else 0.5
                )

                return AgentProposal(
                    agent_id=agent_id,
                    action_name=action_name,
                    action_args=action_args,
                    reasoning=reasoning,
                    confidence=min(confidence, 1.0),
                )

            # Fallback: try to extract any useful info
            return AgentProposal(
                agent_id=agent_id,
                action_name="unknown",
                action_args={},
                reasoning=result[:300],
                confidence=0.3,
            )

        except Exception as e:
            self.logger.warning(f"Failed to parse proposal: {e}")
            return None

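    # Illustrative parse (not part of the commit): a debater reply such as
    #   ACTION: write_file
    #   ARGUMENTS: {"filename": "out.txt", "contents": "hi"}
    #   REASONING: Writing the file directly completes the task.
    #   CONFIDENCE: 0.9
    # yields AgentProposal(action_name="write_file", action_args={...},
    # confidence=0.9); a reply with no ACTION line degrades to the
    # low-confidence "unknown" fallback above. (write_file is a hypothetical
    # command name here.)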
    async def run_critique_phase(self, task: str) -> list[AgentCritique]:
        """Run the critique phase with sub-agents."""
        if not self.can_spawn_sub_agent() or not self.debate_state.proposals:
            return []

        critique_tasks = []
        for i, proposal in enumerate(self.debate_state.proposals):
            # Each other debater critiques this proposal
            for j in range(self.config.num_debaters):
                if j == i:
                    continue  # Don't critique own proposal

                sub_task = self.config.critique_instruction.format(
                    task=task,
                    action=proposal.action_name,
                    arguments=json.dumps(proposal.action_args),
                    reasoning=proposal.reasoning,
                )
                critique_tasks.append((f"critic-{j + 1}", proposal.agent_id, sub_task))

        try:
            # Run critiques (limit parallelism)
            critiques = []
            for critic_id, target_id, sub_task in critique_tasks:
                result = await self.spawn_and_run(
                    sub_task,
                    strategy="one_shot",
                    max_cycles=5,
                )
                if result:
                    critique = self._parse_critique(critic_id, target_id, str(result))
                    if critique:
                        critiques.append(critique)

            self.debate_state.critiques = critiques
            self.current_round += 1

            if self.current_round >= self.config.num_rounds:
                self.phase = DebatePhase.CONSENSUS
            else:
                self.phase = DebatePhase.PROPOSAL  # Another round

            return critiques

        except Exception as e:
            self.logger.error(f"Critique phase failed: {e}")
            return []

    def _parse_critique(
        self, critic_id: str, target_id: str, result: str
    ) -> Optional[AgentCritique]:
        """Parse a critique from sub-agent output."""
        try:
            strengths = re.findall(
                r"STRENGTHS?:\s*(.+?)(?=WEAKNESSES?:|$)",
                result,
                re.DOTALL | re.IGNORECASE,
            )
            weaknesses = re.findall(
                r"WEAKNESSES?:\s*(.+?)(?=SUGGESTIONS?:|$)",
                result,
                re.DOTALL | re.IGNORECASE,
            )
            suggestions = re.findall(
                r"SUGGESTIONS?:\s*(.+?)(?=SCORE:|$)", result, re.DOTALL | re.IGNORECASE
            )
            score_match = re.search(r"SCORE:\s*([\d.]+)", result)

            return AgentCritique(
                critic_id=critic_id,
                target_agent_id=target_id,
                strengths=[s.strip() for s in strengths] if strengths else [],
                weaknesses=[w.strip() for w in weaknesses] if weaknesses else [],
                suggestions=[s.strip() for s in suggestions] if suggestions else [],
                score=float(score_match.group(1)) if score_match else 0.5,
            )

        except Exception as e:
            self.logger.warning(f"Failed to parse critique: {e}")
            return None

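    # Illustrative parse (not part of the commit): a critic reply such as
    #   STRENGTHS: Direct and minimal.
    #   WEAKNESSES: Skips verification.
    #   SUGGESTIONS: Read the file back afterwards.
    #   SCORE: 0.8
    # yields AgentCritique(score=0.8, ...); a reply missing SCORE defaults
    # to the neutral 0.5.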
    def determine_consensus(self) -> Optional[AgentProposal]:
        """Determine consensus from proposals and critiques."""
        if not self.debate_state.proposals:
            return None

        # Score proposals by average critique score
        scored_proposals = []
        for proposal in self.debate_state.proposals:
            score = self._get_proposal_score(proposal.agent_id)
            # Also factor in original confidence
            combined_score = 0.7 * score + 0.3 * proposal.confidence
            scored_proposals.append((proposal, combined_score))

        # Sort by score
        scored_proposals.sort(key=lambda x: x[1], reverse=True)

        # Check if top proposal meets consensus threshold
        if scored_proposals:
            best_proposal, best_score = scored_proposals[0]
            if best_score >= self.config.consensus_threshold:
                self.debate_state.consensus_reached = True
                self.debate_state.winning_proposal = best_proposal
                self.phase = DebatePhase.EXECUTION
                return best_proposal

        # If no clear winner, return highest scored anyway
        if scored_proposals:
            self.debate_state.winning_proposal = scored_proposals[0][0]
            return scored_proposals[0][0]

        return None

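    # Worked example (illustrative, not part of the commit): a proposal with
    # average critique score 0.8 and self-reported confidence 0.6 scores
    # 0.7 * 0.8 + 0.3 * 0.6 = 0.74, which clears the default
    # consensus_threshold of 0.7, so the debate moves to EXECUTION.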
    def parse_response_content(
        self,
        response: AssistantChatMessage,
    ) -> DebateActionProposal:
        """Parse the LLM response into a debate action proposal."""
        if not response.content:
            raise InvalidAgentResponseError("Assistant response has no text content")

        self.logger.debug(f"LLM response content:\n{response.content[:500]}")

        assistant_reply_dict = extract_dict_from_json(response.content)

        if not response.tool_calls:
            raise InvalidAgentResponseError("Assistant did not use a tool")

        assistant_reply_dict["use_tool"] = response.tool_calls[0].function

        parsed_response = DebateActionProposal.model_validate(assistant_reply_dict)
        parsed_response.raw_message = response.model_copy()

        return parsed_response

    def record_execution_result(
        self, variable_name: str, result: str, error: Optional[str] = None
    ) -> None:
        """Record execution result."""
        # Reset for next decision if needed
        if self.phase == DebatePhase.EXECUTION:
            self.debate_state = DebateState()
            self.phase = DebatePhase.PROPOSAL
            self.current_round = 0
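
# Illustrative end-to-end flow (not part of the commit): how a coordinator
# could drive the three phases. `strategy` is a MultiAgentDebateStrategy;
# `task` and `execute` are hypothetical.
#
#   proposals = await strategy.run_proposal_phase(task)
#   critiques = await strategy.run_critique_phase(task)
#   winner = strategy.determine_consensus()
#   if winner:
#       execute(winner.action_name, winner.action_args)
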
@@ -6,6 +6,8 @@ import re
from logging import Logger

import distro
from pydantic import Field

from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.json.parsing import extract_dict_from_json
@@ -21,7 +23,6 @@ from forge.models.config import SystemConfiguration, UserConfigurable
from forge.models.json_schema import JSONSchema
from forge.models.utils import ModelWithSummary
from forge.utils.exceptions import InvalidAgentResponseError
from pydantic import Field

_RESPONSE_INTERFACE_NAME = "AssistantResponse"


@@ -25,6 +25,8 @@ from enum import Enum
from logging import Logger
from typing import Optional, Union

from pydantic import Field

from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.json.parsing import extract_dict_from_json
@@ -39,7 +41,6 @@ from forge.models.config import UserConfigurable
from forge.models.json_schema import JSONSchema
from forge.models.utils import ModelWithSummary
from forge.utils.exceptions import InvalidAgentResponseError
from pydantic import Field

from .base import (
    BaseMultiStepPromptStrategy,

@@ -27,6 +27,8 @@ from enum import Enum
from logging import Logger
from typing import Any, Literal, Optional

from pydantic import Field

from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.json.parsing import extract_dict_from_json
@@ -41,7 +43,6 @@ from forge.models.config import UserConfigurable
from forge.models.json_schema import JSONSchema
from forge.models.utils import ModelWithSummary
from forge.utils.exceptions import InvalidAgentResponseError
from pydantic import Field

from .base import (
    BaseMultiStepPromptStrategy,

@@ -24,6 +24,8 @@ from enum import Enum
from logging import Logger
from typing import Any, Optional

from pydantic import Field

from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.json.parsing import extract_dict_from_json
@@ -39,7 +41,6 @@ from forge.models.config import UserConfigurable
from forge.models.json_schema import JSONSchema
from forge.models.utils import ModelWithSummary
from forge.utils.exceptions import InvalidAgentResponseError
from pydantic import Field

from .base import (
    BaseMultiStepPromptStrategy,
@@ -181,7 +182,14 @@ class ReWOOPromptConfiguration(BasePromptStrategyConfiguration):
    """Configuration for ReWOO prompt strategy."""

    DEFAULT_PLANNER_INSTRUCTION: str = (
        "Create a complete plan to accomplish the task. For each step:\n"
        "Create a complete plan to FULLY ACCOMPLISH the task. Your plan must include "
        "all steps needed to produce the final deliverable - not just exploration.\n\n"
        "IMPORTANT:\n"
        "- Do NOT end with exploration, research, or todo/planning steps\n"
        "- Your plan must result in the actual task being COMPLETE\n"
        "- Include steps that create/modify files, write code, or produce output\n"
        "- The final steps should verify the task is done, not plan future work\n\n"
        "For each step:\n"
        "1. Write your reasoning (Plan:)\n"
        "2. Specify the tool to use and its arguments\n"
        "3. Assign a variable name (#E1, #E2, etc.) to store the result\n"
@@ -194,10 +202,14 @@ class ReWOOPromptConfiguration(BasePromptStrategyConfiguration):

    # Paper-style planner instruction (uses bracket syntax like the original paper)
    DEFAULT_PAPER_PLANNER_INSTRUCTION: str = (
        "For the following task, make plans that can solve the problem step by step. "
        "For the following task, make plans that can FULLY SOLVE the problem "
        "step by step. "
        "For each plan, indicate which external tool together with tool input to "
        "retrieve evidence. You can store the evidence into a variable #E[n] that "
        "can be called by later tools.\n\n"
        "IMPORTANT: Your plan must COMPLETE the task, not just explore or prepare. "
        "Do not end with research or planning steps - include all actions needed "
        "to produce the final deliverable.\n\n"
        "Tools can be one of the following:\n"
        "{available_tools}\n\n"
        "Format:\n"
@@ -211,11 +223,16 @@

    DEFAULT_SYNTHESIZER_INSTRUCTION: str = (
        "You have executed the following plan and received these results.\n"
        "Analyze the results and provide a final response to the original task.\n\n"
        "Analyze the results and determine if the ORIGINAL TASK has been "
        "accomplished.\n\n"
        "Plan and Results:\n{plan_with_results}\n\n"
        "Original Task: {task}\n\n"
        "Provide your synthesis in the required JSON format, then use the "
        "appropriate command to complete the task or report findings."
        "IMPORTANT: The task is only complete if you have PRODUCED THE DELIVERABLE "
        "(created/modified files, written code, generated output, etc). If you only "
        "explored or planned, the task is NOT complete.\n\n"
        "If the task is truly complete, call `finish` with your final answer. "
        "If the task is NOT complete (only explored/planned), you must call other "
        "commands to actually complete the work."
    )

    # Paper-style synthesizer instruction (Solver module from paper)
@@ -226,7 +243,9 @@
        "with caution.\n\n"
        "Task: {task}\n\n"
        "Plans and Evidences:\n{plan_with_results}\n\n"
        "Now solve the task. Provide your answer with clear reasoning."
        "IMPORTANT: The task is only complete if you have PRODUCED THE DELIVERABLE. "
        "If you only explored or planned, call commands to complete the actual work. "
        "If the task is truly complete, call `finish` with your final answer."
    )

    planner_instruction: str = UserConfigurable(default=DEFAULT_PLANNER_INSTRUCTION)
@@ -630,12 +649,52 @@ class ReWOOPromptStrategy(BaseMultiStepPromptStrategy):
        return arguments

    def _parse_tool_arguments(self, args_str: str) -> dict[str, Any]:
        """Parse tool arguments from a string like 'arg1="value1", arg2=123'."""
        """Parse tool arguments from a string like 'arg1="value1", arg2=123'.

        Supports:
        - String values: arg="value"
        - Numbers: arg=123 or arg=1.5
        - Variable references: arg=#E1
        - JSON arrays: arg=[{...}]
        - JSON objects: arg={...}
        - Booleans: arg=true/false
        """
        arguments: dict[str, Any] = {}

        # Simple pattern matching for key=value pairs
        # This handles strings, numbers, and variable references
        arg_pattern = re.compile(r'(\w+)\s*=\s*(?:"([^"]*)"|(#E\d+)|(\d+\.?\d*))')
        # Try to parse as JSON-like structure
        # Handle both JSON format ("key": value) and Python format (key=value)
        try:
            json_str = args_str.strip()

            # Convert Python-style key=value to JSON-style "key": value
            # Match: word= at start or after comma, followed by any value
            json_str = re.sub(r"(?:^|,\s*)(\w+)\s*=\s*", r'"\1": ', json_str)

            # Wrap in braces
            json_str = "{" + json_str + "}"

            # Fix common Python-isms: single quotes -> double, True/False -> true/false
            json_str = json_str.replace("'", '"')
            json_str = re.sub(r"\bTrue\b", "true", json_str)
            json_str = re.sub(r"\bFalse\b", "false", json_str)
            json_str = re.sub(r"\bNone\b", "null", json_str)

            parsed = json.loads(json_str)
            if isinstance(parsed, dict):
                return parsed
        except (json.JSONDecodeError, ValueError):
            pass

        # Fall back to regex-based parsing for simpler cases
        # Pattern for key=value where value can be:
        # - quoted string: "..."
        # - variable reference: #E1
        # - number: 123 or 1.5
        # - boolean: true/false
        arg_pattern = re.compile(
            r'(\w+)\s*=\s*(?:"([^"]*)"|(#E\d+)|(\d+\.?\d*)|(true|false))',
            re.IGNORECASE,
        )

        for match in arg_pattern.finditer(args_str):
            key = match.group(1)
@@ -646,6 +705,8 @@ class ReWOOPromptStrategy(BaseMultiStepPromptStrategy):
            elif match.group(4) is not None:  # Number
                num_str = match.group(4)
                arguments[key] = float(num_str) if "." in num_str else int(num_str)
            elif match.group(5) is not None:  # Boolean
                arguments[key] = match.group(5).lower() == "true"

        return arguments
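
# Illustrative parses (not part of the commit) for _parse_tool_arguments; the
# exact path taken (JSON conversion vs. regex fallback) depends on the input:
#
#   'data=[{"x": 1}]'        -> {"data": [{"x": 1}]}            (JSON path)
#   'path="a.txt", lines=3'  -> {"path": "a.txt", "lines": 3}   (regex fallback)
#   'source=#E1'             -> the "#E1" reference is kept, presumably for
#                               later variable substitution (regex fallback)
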
@@ -27,6 +27,8 @@ from enum import Enum
from logging import Logger
from typing import Any, Literal, Optional

from pydantic import Field

from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.json.parsing import extract_dict_from_json
@@ -42,7 +44,6 @@ from forge.models.config import UserConfigurable
from forge.models.json_schema import JSONSchema
from forge.models.utils import ModelWithSummary
from forge.utils.exceptions import InvalidAgentResponseError
from pydantic import Field

from .base import BaseMultiStepPromptStrategy, BasePromptStrategyConfiguration, Thought

@@ -10,6 +10,14 @@ from fastapi import APIRouter, FastAPI, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from hypercorn.asyncio import serve as hypercorn_serve
from hypercorn.config import Config as HypercornConfig
from sentry_sdk import set_user

from autogpt.agent_factory.configurators import configure_agent_with_state, create_agent
from autogpt.agents.agent_manager import AgentManager
from autogpt.app.config import AppConfig
from autogpt.app.utils import is_port_free
from forge.agent_protocol.api_router import base_router
from forge.agent_protocol.database import AgentDB
from forge.agent_protocol.middlewares import AgentMiddleware
@@ -28,14 +36,6 @@ from forge.llm.providers import ModelProviderBudget, MultiProvider
from forge.models.action import ActionErrorResult, ActionSuccessResult
from forge.utils.const import ASK_COMMAND, FINISH_COMMAND
from forge.utils.exceptions import AgentFinished, NotFoundError
from hypercorn.asyncio import serve as hypercorn_serve
from hypercorn.config import Config as HypercornConfig
from sentry_sdk import set_user

from autogpt.agent_factory.configurators import configure_agent_with_state, create_agent
from autogpt.agents.agent_manager import AgentManager
from autogpt.app.config import AppConfig
from autogpt.app.utils import is_port_free

logger = logging.getLogger(__name__)

@@ -5,6 +5,7 @@ from pathlib import Path
from typing import Optional

import click

from forge.logging.config import LogFormatName

from .telemetry import setup_telemetry

@@ -8,12 +8,13 @@ import re
from pathlib import Path
from typing import Literal, Optional, Union

from pydantic import SecretStr

from forge.config.base import BaseConfig
from forge.llm.providers import ModelName
from forge.llm.providers.openai import OpenAICredentials, OpenAIModelName
from forge.logging.config import LoggingConfig
from forge.models.config import Configurable, UserConfigurable
from pydantic import SecretStr

# Type alias for prompt strategy options
PromptStrategyName = Literal[
@@ -149,9 +150,10 @@ async def assert_config_has_required_llm_api_keys(config: AppConfig) -> None:
    """
    Check if API keys (if required) are set for the configured SMART_LLM and FAST_LLM.
    """
    from pydantic import ValidationError

    from forge.llm.providers.anthropic import AnthropicModelName
    from forge.llm.providers.groq import GroqModelName
    from pydantic import ValidationError

    if set((config.smart_llm, config.fast_llm)).intersection(AnthropicModelName):
        from forge.llm.providers.anthropic import AnthropicCredentials
@@ -181,9 +183,10 @@ async def assert_config_has_required_llm_api_keys(config: AppConfig) -> None:
    )

    if set((config.smart_llm, config.fast_llm)).intersection(GroqModelName):
        from forge.llm.providers.groq import GroqProvider
        from groq import AuthenticationError

        from forge.llm.providers.groq import GroqProvider

        try:
            groq = GroqProvider()
            await groq.get_available_models()
@@ -206,9 +209,10 @@ async def assert_config_has_required_llm_api_keys(config: AppConfig) -> None:
        raise ValueError("Groq is unavailable: invalid API key") from e

    if set((config.smart_llm, config.fast_llm)).intersection(OpenAIModelName):
        from forge.llm.providers.openai import OpenAIProvider
        from openai import AuthenticationError

        from forge.llm.providers.openai import OpenAIProvider

        try:
            openai = OpenAIProvider()
            await openai.get_available_models()

@@ -1,13 +1,14 @@
"""Configurator module."""

from __future__ import annotations

import logging
from typing import Literal, Optional

import click
from forge.llm.providers import ModelName, MultiProvider

from autogpt.app.config import GPT_3_MODEL, AppConfig
from forge.llm.providers import ModelName, MultiProvider

logger = logging.getLogger(__name__)

@@ -14,6 +14,15 @@ from types import FrameType
from typing import TYPE_CHECKING, Optional

from colorama import Fore, Style

from autogpt.agent_factory.configurators import configure_agent_with_state, create_agent
from autogpt.agents.agent_manager import AgentManager
from autogpt.agents.prompt_strategies.one_shot import AssistantThoughts
from autogpt.app.config import (
    AppConfig,
    ConfigBuilder,
    assert_config_has_required_llm_api_keys,
)
from forge.agent_protocol.database import AgentDB
from forge.components.code_executor.code_executor import (
    is_docker_available,
@@ -36,15 +45,6 @@ from forge.utils.exceptions import (
    InvalidAgentResponseError,
)

from autogpt.agent_factory.configurators import configure_agent_with_state, create_agent
from autogpt.agents.agent_manager import AgentManager
from autogpt.agents.prompt_strategies.one_shot import AssistantThoughts
from autogpt.app.config import (
    AppConfig,
    ConfigBuilder,
    assert_config_has_required_llm_api_keys,
)

if TYPE_CHECKING:
    from autogpt.agents.agent import Agent
    from autogpt.app.ui.protocol import UIProvider

@@ -1,13 +1,13 @@
"""Set up the AI and its goals"""

import logging
from typing import Optional

from autogpt.app.config import AppConfig
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
from forge.logging.utils import print_attribute

from autogpt.app.config import AppConfig

logger = logging.getLogger(__name__)

@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Optional

import click
from colorama import Fore, Style

from forge.logging.utils import print_attribute
from forge.permissions import ApprovalScope

@@ -138,9 +139,8 @@ class TerminalUIProvider(UIProvider):
        thoughts: The agent's thoughts (string or structured).
        speak_mode: Whether to use text-to-speech.
    """
    from forge.models.utils import ModelWithSummary

    from autogpt.agents.prompt_strategies.one_shot import AssistantThoughts
    from forge.models.utils import ModelWithSummary

    thoughts_text = self._remove_ansi_escape(
        thoughts.text

@@ -5,11 +5,11 @@ from pathlib import Path
from typing import Optional

import click
from forge.llm.providers import ChatMessage, MultiProvider
from forge.llm.providers.anthropic import AnthropicModelName
from git import Repo, TagReference

from autogpt.app.utils import coroutine
from forge.llm.providers import ChatMessage, MultiProvider
from forge.llm.providers.anthropic import AnthropicModelName


@click.command()
@@ -132,6 +132,7 @@ Do not mention the changes in the example when writing your release notes!

if __name__ == "__main__":
    import dotenv

    from forge.logging.config import configure_logging

    configure_logging(debug=True)

@@ -5,6 +5,10 @@ import uuid
from pathlib import Path

import pytest

from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from autogpt.app.config import AppConfig, ConfigBuilder
from autogpt.app.main import _configure_llm_provider
from forge.config.ai_profile import AIProfile
from forge.file_storage.local import (
    FileStorage,
@@ -14,10 +18,6 @@ from forge.file_storage.local import (
from forge.llm.providers import MultiProvider
from forge.logging.config import configure_logging

from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from autogpt.app.config import AppConfig, ConfigBuilder
from autogpt.app.main import _configure_llm_provider

pytest_plugins = [
    "tests.integration.agent_factory",
]

@@ -1,12 +1,12 @@
from pathlib import Path

import pytest
from forge.config.ai_profile import AIProfile
from forge.file_storage import FileStorageBackendName, get_storage
from forge.llm.providers import MultiProvider

from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
from autogpt.app.config import AppConfig
from forge.config.ai_profile import AIProfile
from forge.file_storage import FileStorageBackendName, get_storage
from forge.llm.providers import MultiProvider


@pytest.fixture

@@ -1,12 +1,12 @@
import pytest
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile

from autogpt.app.config import AppConfig
from autogpt.app.setup import (
    apply_overrides_to_ai_settings,
    interactively_revise_ai_settings,
)
from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile


@pytest.mark.asyncio

@@ -5,8 +5,8 @@ from unittest.mock import patch

import pytest
import requests
from forge.json.parsing import extract_dict_from_json
from git import InvalidGitRepositoryError
from tests.utils import skip_in_ci

import autogpt.app.utils
from autogpt.app.utils import (
@@ -15,7 +15,7 @@ from autogpt.app.utils import (
    get_latest_bulletin,
    set_env_config_value,
)
from tests.utils import skip_in_ci
from forge.json.parsing import extract_dict_from_json


@pytest.fixture

24
classic/reports/.benchmark_state.json
Normal file
@@ -0,0 +1,24 @@
{
  "session_id": "2026-01-20T00:43:46.908145",
  "started_at": "2026-01-20T00:43:46.908149",
  "completed_runs": {
    "rewoo/claude:ReadFile:1": {
      "config_name": "rewoo/claude",
      "challenge_name": "ReadFile",
      "attempt": 1,
      "success": false,
      "cost": 0.0,
      "n_steps": 8,
      "run_time_seconds": 60.015603,
      "error_message": "Challenge timed out",
      "completed_at": "2026-01-20T00:44:46.936504"
    }
  },
  "strategies": [
    "rewoo"
  ],
  "models": [
    "claude"
  ],
  "attempts": 1
}
38
classic/reports/20260120T004315_one_shot_claude/report.json
Normal file
@@ -0,0 +1,38 @@
{
  "command": "direct_benchmark run --config one_shot/claude",
  "completion_time": "2026-01-20T06:43:37.023065+00:00",
  "benchmark_start_time": "2026-01-20T06:43:15.998389+00:00",
  "metrics": {
    "run_time": "21.02 seconds",
    "highest_difficulty": "interface",
    "total_cost": 0.0715326
  },
  "config": {
    "config_name": "one_shot/claude"
  },
  "tests": {
    "ReadFile": {
      "category": [],
      "difficulty": null,
      "data_path": "",
      "description": "",
      "task": "",
      "answer": "",
      "metrics": {
        "attempted": true,
        "is_regression": false,
        "success_percentage": 100.0
      },
      "results": [
        {
          "success": true,
          "run_time": "20.995 seconds",
          "fail_reason": null,
          "reached_cutoff": false,
          "n_steps": 2,
          "cost": 0.0715326
        }
      ]
    }
  }
}
38
classic/reports/20260120T004346_rewoo_claude/report.json
Normal file
@@ -0,0 +1,38 @@
{
  "command": "direct_benchmark run --config rewoo/claude",
  "completion_time": "2026-01-20T06:44:46.937190+00:00",
  "benchmark_start_time": "2026-01-20T06:43:46.908109+00:00",
  "metrics": {
    "run_time": "60.03 seconds",
    "highest_difficulty": "interface",
    "total_cost": 0.0
  },
  "config": {
    "config_name": "rewoo/claude"
  },
  "tests": {
    "ReadFile": {
      "category": [],
      "difficulty": null,
      "data_path": "",
      "description": "",
      "task": "",
      "answer": "",
      "metrics": {
        "attempted": true,
        "is_regression": false,
        "success_percentage": 0.0
      },
      "results": [
        {
          "success": false,
          "run_time": "60.016 seconds",
          "fail_reason": "Challenge timed out",
          "reached_cutoff": true,
          "n_steps": 8,
          "cost": 0.0
        }
      ]
    }
  }
}