fix(classic): resolve all pyright type errors

- Add missing strategies (lats, multi_agent_debate) to PromptStrategyName
- Fix method override signatures for reasoning_effort parameter
- Fix Pydantic Field() overload issues with helper function
- Fix BeautifulSoup Tag type narrowing in web_fetch.py
- Fix Optional member access in playwright_browser.py and rewoo.py
- Convert hasattr patterns to getattr for proper type narrowing
- Add proper type casts for Literal types
- Fix file storage path type conversions
- Exclude legacy challenges/ from pyright checking

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Nicholas Tindle
2026-01-20 10:41:53 -06:00
parent 7d6375f59c
commit a4d7b0142f
18 changed files with 306 additions and 126 deletions

View File

@@ -127,15 +127,21 @@ class FailureAnalyzer:
def __init__(self, reports_dir: Path, use_llm: bool = True):
self.reports_dir = reports_dir
self.use_llm = use_llm
self.console = Console() if RICH_AVAILABLE else None
self._console_instance = Console() if RICH_AVAILABLE else None
self.strategies: dict[str, StrategyAnalysis] = {}
self.test_comparison: dict[str, dict[str, TestResult]] = defaultdict(dict)
self._llm_provider = None
def _print(self, *args, **kwargs):
@property
def console(self) -> Any:
"""Get console instance (only call when RICH_AVAILABLE is True)."""
assert self._console_instance is not None
return self._console_instance
def _print(self, *args: Any, **kwargs: Any) -> None:
"""Print with Rich if available, otherwise standard print."""
if self.console:
self.console.print(*args, **kwargs)
if self._console_instance:
self._console_instance.print(*args, **kwargs)
else:
print(*args, **kwargs)

View File

@@ -3,13 +3,19 @@
import os
import sys
from pathlib import Path
from typing import Optional
from typing import Optional, cast
import click
from .challenge_loader import find_challenges_dir
from .harness import BenchmarkHarness
from .models import MODEL_PRESETS, STRATEGIES, BenchmarkConfig, HarnessConfig
from .models import (
MODEL_PRESETS,
STRATEGIES,
BenchmarkConfig,
HarnessConfig,
StrategyName,
)
from .ui import console
@@ -272,7 +278,7 @@ def run(
model = MODEL_PRESETS[model_name]
configs.append(
BenchmarkConfig(
strategy=strategy,
strategy=cast(StrategyName, strategy),
model=model,
max_steps=max_steps,
timeout_seconds=timeout,

View File

@@ -180,9 +180,7 @@ class BenchmarkHarness:
self.state_manager.mark_completed(progress.result, attempt)
# Create step callback if UI supports it
step_callback = None
if hasattr(ui, "log_step"):
step_callback = ui.log_step
step_callback = getattr(ui, "log_step", None)
# Create skip function for resume functionality
def should_skip(config_name: str, challenge_name: str, attempt: int) -> bool:

View File

@@ -5,7 +5,10 @@ import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional, cast
if TYPE_CHECKING:
from forge.llm.providers import ModelName
from autogpt.agent_factory.configurators import create_agent
from autogpt.agents.agent import Agent
@@ -159,9 +162,9 @@ class AgentRunner:
# Apply model and strategy configuration
if self.config.model.smart_llm:
app_config.smart_llm = self.config.model.smart_llm
app_config.smart_llm = cast("ModelName", self.config.model.smart_llm)
if self.config.model.fast_llm:
app_config.fast_llm = self.config.model.fast_llm
app_config.fast_llm = cast("ModelName", self.config.model.fast_llm)
app_config.prompt_strategy = self.config.strategy
app_config.noninteractive_mode = True
app_config.continuous_mode = True
@@ -253,9 +256,7 @@ class AgentRunner:
cumulative_cost = self._llm_provider.get_incurred_cost()
# Get result info
result_str = str(
result.outputs if hasattr(result, "outputs") else result
)
result_str = str(getattr(result, "outputs", result))
is_error = hasattr(result, "status") and result.status == "error"
# Record step

View File

@@ -6,7 +6,7 @@ from datetime import datetime
from typing import Optional
from rich.columns import Columns
from rich.console import Console, Group, RenderableType
from rich.console import Console, ConsoleOptions, Group, RenderResult
from rich.panel import Panel
from rich.progress import (
BarColumn,
@@ -70,9 +70,9 @@ class BenchmarkUI:
configure_logging_for_benchmark()
# Track state - use run_key (config:challenge) for uniqueness
self.active_runs: dict[
str, tuple[str, str]
] = {} # run_key -> (config_name, challenge_name)
self.active_runs: dict[str, tuple[str, str]] = (
{}
) # run_key -> (config_name, challenge_name)
self.active_steps: dict[str, str] = {} # run_key -> current step info
self.completed: list[ChallengeResult] = []
self.results_by_config: dict[str, list[ChallengeResult]] = {}
@@ -414,7 +414,9 @@ class BenchmarkUI:
self.render_recent_completions(),
)
def __rich_console__(self, console: Console, options) -> RenderableType:
def __rich_console__(
self, console: Console, options: ConsoleOptions
) -> RenderResult:
"""Support for Rich Live display - called on each refresh."""
yield self.render_live_display()

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
import logging
from pathlib import Path
from typing import Iterator, Literal, Optional
from typing import Any, Iterator, Literal, Optional
from bs4 import BeautifulSoup
from pydantic import BaseModel, SecretStr
@@ -186,7 +186,7 @@ class WebPlaywrightComponent(
"""Launch a local browser instance."""
browser_launcher = getattr(self._playwright, self.config.browser_type)
launch_args = {
launch_args: dict[str, Any] = {
"headless": self.config.headless,
}
@@ -219,6 +219,7 @@ class WebPlaywrightComponent(
async def _create_context(self):
"""Create a browser context with configured settings."""
assert self._browser is not None, "Browser not initialized"
context = await self._browser.new_context(
user_agent=self.config.user_agent,
viewport={"width": 1920, "height": 1080},
@@ -296,6 +297,7 @@ class WebPlaywrightComponent(
async def _open_page(self, url: str):
"""Open a new page and navigate to URL with smart waiting."""
await self._ensure_browser()
assert self._context is not None, "Browser context not initialized"
page = await self._context.new_page()
try:

View File

@@ -102,9 +102,9 @@ class WebSearchConfiguration(BaseModel):
)
# Legacy aliases for backwards compatibility
duckduckgo_max_attempts: int = 3 # Now used as max backend attempts
duckduckgo_backend: Literal[
"api", "html", "lite"
] = "api" # Ignored, use ddgs_backend
duckduckgo_backend: Literal["api", "html", "lite"] = (
"api" # Ignored, use ddgs_backend
)
class WebSearchComponent(
@@ -123,7 +123,7 @@ class WebSearchComponent(
def __init__(self, config: Optional[WebSearchConfiguration] = None):
ConfigurableComponent.__init__(self, config)
self._ddgs_client: Optional["DDGS"] = None
self._ddgs_client: "Optional[DDGS]" = None # type: ignore[type-arg]
self._log_provider_status()
def _log_provider_status(self) -> None:
@@ -145,7 +145,7 @@ class WebSearchComponent(
)
@property
def ddgs_client(self) -> "DDGS":
def ddgs_client(self) -> "DDGS": # type: ignore[type-arg]
"""Lazy-loaded DDGS client."""
if self._ddgs_client is None:
self._ddgs_client = DDGS()

View File

@@ -11,7 +11,7 @@ from urllib.parse import urljoin
import httpx
import trafilatura
from bs4 import BeautifulSoup
from bs4 import BeautifulSoup, Tag
from pydantic import BaseModel
from forge.agent.components import ConfigurableComponent
@@ -158,30 +158,30 @@ class WebFetchComponent(
# Meta description
desc = soup.find("meta", attrs={"name": "description"})
if desc and desc.get("content"):
metadata["description"] = desc["content"]
if isinstance(desc, Tag) and desc.get("content"):
metadata["description"] = str(desc["content"])
# Open Graph title/description
og_title = soup.find("meta", attrs={"property": "og:title"})
if og_title and og_title.get("content"):
metadata["og_title"] = og_title["content"]
if isinstance(og_title, Tag) and og_title.get("content"):
metadata["og_title"] = str(og_title["content"])
og_desc = soup.find("meta", attrs={"property": "og:description"})
if og_desc and og_desc.get("content"):
metadata["og_description"] = og_desc["content"]
if isinstance(og_desc, Tag) and og_desc.get("content"):
metadata["og_description"] = str(og_desc["content"])
# Author
author = soup.find("meta", attrs={"name": "author"})
if author and author.get("content"):
metadata["author"] = author["content"]
if isinstance(author, Tag) and author.get("content"):
metadata["author"] = str(author["content"])
# Published date
for attr in ["article:published_time", "datePublished", "date"]:
date_tag = soup.find("meta", attrs={"property": attr}) or soup.find(
"meta", attrs={"name": attr}
)
if date_tag and date_tag.get("content"):
metadata["published"] = date_tag["content"]
if isinstance(date_tag, Tag) and date_tag.get("content"):
metadata["published"] = str(date_tag["content"])
break
return metadata
@@ -253,12 +253,20 @@ class WebFetchComponent(
if output_format == "markdown":
content = trafilatura.extract(
html, output_format="markdown", **extract_kwargs
html,
output_format="markdown",
**extract_kwargs, # type: ignore[arg-type]
)
elif output_format == "xml":
content = trafilatura.extract(html, output_format="xml", **extract_kwargs)
content = trafilatura.extract(
html,
output_format="xml",
**extract_kwargs, # type: ignore[arg-type]
)
else:
content = trafilatura.extract(html, **extract_kwargs)
content = trafilatura.extract(
html, **extract_kwargs # type: ignore[arg-type]
)
if not content:
# Fallback to basic BeautifulSoup extraction

View File

@@ -250,7 +250,8 @@ class FileSyncHandler(FileSystemEventHandler):
if event.is_directory:
return
file_path = Path(event.src_path).relative_to(self.path)
src_path = str(event.src_path)
file_path = Path(src_path).relative_to(self.path)
content = file_path.read_bytes()
# Must execute write_file synchronously because the hook is synchronous
# TODO: Schedule write operation using asyncio.create_task (non-blocking)
@@ -259,11 +260,12 @@ class FileSyncHandler(FileSystemEventHandler):
)
def on_created(self, event: FileSystemEvent):
src_path = str(event.src_path)
if event.is_directory:
self.storage.make_dir(event.src_path)
self.storage.make_dir(src_path)
return
file_path = Path(event.src_path).relative_to(self.path)
file_path = Path(src_path).relative_to(self.path)
content = file_path.read_bytes()
# Must execute write_file synchronously because the hook is synchronous
# TODO: Schedule write operation using asyncio.create_task (non-blocking)
@@ -272,12 +274,12 @@ class FileSyncHandler(FileSystemEventHandler):
)
def on_deleted(self, event: FileSystemEvent):
src_path = str(event.src_path)
if event.is_directory:
self.storage.delete_dir(event.src_path)
self.storage.delete_dir(src_path)
return
file_path = event.src_path
self.storage.delete_file(file_path)
self.storage.delete_file(src_path)
def on_moved(self, event: FileSystemEvent):
self.storage.rename(event.src_path, event.dest_path)
self.storage.rename(str(event.src_path), str(event.dest_path))

View File

@@ -23,6 +23,7 @@ from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessage,
ChatCompletionMessageParam,
ChatCompletionMessageToolCall,
CompletionCreateParams,
)
from openai.types.shared_params import FunctionDefinition
@@ -454,11 +455,13 @@ class BaseOpenAIChatProvider(
if assistant_message.tool_calls:
for _tc in assistant_message.tool_calls:
# Standard tool calls have a function attribute
tc = cast(ChatCompletionMessageToolCall, _tc)
try:
parsed_arguments = json_loads(_tc.function.arguments)
parsed_arguments = json_loads(tc.function.arguments)
except Exception as e:
err_message = (
f"Decoding arguments for {_tc.function.name} failed: "
f"Decoding arguments for {tc.function.name} failed: "
+ str(e.args[0])
)
parse_errors.append(
@@ -470,10 +473,10 @@ class BaseOpenAIChatProvider(
tool_calls.append(
AssistantToolCall(
id=_tc.id,
type=_tc.type,
id=tc.id,
type=tc.type,
function=AssistantFunctionCall(
name=_tc.function.name,
name=tc.function.name,
arguments=parsed_arguments,
),
)

View File

@@ -2,7 +2,7 @@ import enum
import logging
import re
from pathlib import Path
from typing import Any, Iterator, Optional, Sequence
from typing import Any, Iterator, Literal, Optional, Sequence, cast
import requests
from openai.types.chat import (
@@ -138,7 +138,7 @@ class LlamafileProvider(
self._logger.debug(f"Cleaned llamafile model IDs: {clean_model_ids}")
return [
LLAMAFILE_CHAT_MODELS[id]
LLAMAFILE_CHAT_MODELS[cast(LlamafileModelName, id)]
for id in clean_model_ids
if id in LLAMAFILE_CHAT_MODELS
]
@@ -199,12 +199,18 @@ class LlamafileProvider(
model: LlamafileModelName,
functions: list[CompletionModelFunction] | None = None,
max_output_tokens: int | None = None,
reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
**kwargs,
) -> tuple[
list[ChatCompletionMessageParam], CompletionCreateParams, dict[str, Any]
]:
messages, completion_kwargs, parse_kwargs = super()._get_chat_completion_args(
prompt_messages, model, functions, max_output_tokens, **kwargs
prompt_messages,
model,
functions,
max_output_tokens,
reasoning_effort,
**kwargs,
)
if model == LlamafileModelName.MISTRAL_7B_INSTRUCT:
@@ -237,7 +243,8 @@ class LlamafileProvider(
See details here:
https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format
"""
adapted_messages: list[ChatCompletionMessageParam] = []
# Use Any for working list since we do runtime type transformations
adapted_messages: list[Any] = []
for message in messages:
# convert 'system' role to 'user' role as mistral-7b-instruct does
# not support 'system'
@@ -268,13 +275,14 @@ class LlamafileProvider(
else [{"type": "text", "text": message["content"]}]
)
elif message["role"] != "user" and last_message["role"] != "user":
# Non-user messages have string content
prev_content = str(last_message.get("content") or "")
curr_content = str(message.get("content") or "")
last_message["content"] = (
(last_message.get("content") or "")
+ "\n\n"
+ (message.get("content") or "")
prev_content + "\n\n" + curr_content
).strip()
return adapted_messages
return cast(list[ChatCompletionMessageParam], adapted_messages)
def _parse_assistant_tool_calls(
self,

View File

@@ -2,7 +2,17 @@ import enum
import logging
import os
from pathlib import Path
from typing import Any, Callable, Iterator, Mapping, Optional, ParamSpec, TypeVar, cast
from typing import (
Any,
Callable,
Iterator,
Literal,
Mapping,
Optional,
ParamSpec,
TypeVar,
cast,
)
import tenacity
import tiktoken
@@ -691,6 +701,7 @@ class OpenAIProvider(
model: OpenAIModelName,
functions: Optional[list[CompletionModelFunction]] = None,
max_output_tokens: Optional[int] = None,
reasoning_effort: Optional[Literal["low", "medium", "high"]] = None,
**kwargs,
) -> tuple[
list[ChatCompletionMessageParam], CompletionCreateParams, dict[str, Any]
@@ -721,6 +732,7 @@ class OpenAIProvider(
model=model,
functions=functions,
max_output_tokens=max_output_tokens,
reasoning_effort=reasoning_effort,
**kwargs,
)
kwargs.update(self._credentials.get_model_access_kwargs(model)) # type: ignore

View File

@@ -1,6 +1,6 @@
import os
import typing
from typing import Any, Callable, Generic, Optional, Type, TypeVar, get_args
from typing import Any, Callable, Generic, Optional, Type, TypeVar, cast, get_args
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from pydantic._internal._model_construction import ( # HACK shouldn't be used
@@ -13,28 +13,53 @@ T = TypeVar("T")
M = TypeVar("M", bound=BaseModel)
def _call_default_factory(factory: Callable[..., Any]) -> Any:
"""Call a Pydantic default_factory.
Pydantic's type stubs incorrectly type default_factory, but at runtime
it's always Callable[[], T]. This helper provides proper typing.
"""
return cast(Callable[[], Any], factory)()
def UserConfigurable(
default: T | PydanticUndefinedType = PydanticUndefined,
*args,
*, # Force keyword-only arguments after default
default_factory: Optional[Callable[[], T]] = None,
from_env: Optional[str | Callable[[], T | None]] = None,
description: str = "",
exclude: bool = False,
**kwargs,
**kwargs: Any,
) -> T:
# TODO: use this to auto-generate docs for the application configuration
field_info: FieldInfo = Field(
default,
*args,
default_factory=default_factory,
description=description,
exclude=exclude,
**kwargs,
)
"""Create a user-configurable field with optional environment variable support.
Args:
default: Default value for the field
default_factory: Factory function to create default value
from_env: Environment variable name or callable to get value from environment
description: Field description
exclude: Whether to exclude from serialization
**kwargs: Additional arguments passed to Pydantic Field()
"""
# Handle Field() overload - it expects either default OR default_factory, not both
if default_factory is not None:
field_info: FieldInfo = Field(
default_factory=default_factory,
description=description,
exclude=exclude,
**kwargs,
)
else:
field_info = Field(
default=default,
description=description,
exclude=exclude,
**kwargs,
)
field_info.metadata.append(("user_configurable", True))
field_info.metadata.append(("from_env", from_env))
return field_info # type: ignore
return cast(T, field_info)
def _get_field_metadata(field: FieldInfo, key: str, default: Any = None) -> Any:
@@ -64,7 +89,7 @@ class SystemConfiguration(BaseModel):
field.default
if field.default not in (None, PydanticUndefined)
else (
field.default_factory()
_call_default_factory(field.default_factory)
if field.default_factory
else PydanticUndefined
)
@@ -141,7 +166,11 @@ def _update_user_config_from_env(instance: BaseModel) -> dict[str, Any]:
default_value = (
field.default
if field.default not in (None, PydanticUndefined)
else (field.default_factory() if field.default_factory else None)
else (
_call_default_factory(field.default_factory)
if field.default_factory
else None
)
)
if value == default_value and (
from_env := _get_field_metadata(field, "from_env")
@@ -330,7 +359,11 @@ def _get_non_default_user_config_values(instance: BaseModel) -> dict[str, Any]:
"""
def get_field_value(field: FieldInfo, value):
default = field.default_factory() if field.default_factory else field.default
default = (
_call_default_factory(field.default_factory)
if field.default_factory
else field.default
)
if value != default:
return value

View File

@@ -7,7 +7,7 @@ runner for agent creation.
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, cast
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
@@ -16,7 +16,7 @@ from forge.config.ai_directives import AIDirectives
from forge.config.ai_profile import AIProfile
if TYPE_CHECKING:
from autogpt.app.config import AppConfig
from autogpt.app.config import AppConfig, PromptStrategyName
class DefaultAgentFactory(AgentFactory):
@@ -99,7 +99,7 @@ class DefaultAgentFactory(AgentFactory):
# Copy app config and optionally override strategy
config = self.app_config.model_copy(deep=True)
if strategy:
config.prompt_strategy = strategy
config.prompt_strategy = cast("PromptStrategyName", strategy)
# Sub-agents should always be non-interactive
config.noninteractive_mode = True

View File

@@ -152,8 +152,10 @@ class Agent(BaseAgent[AnyActionProposal], Configurable[AgentSettings]):
# Create prompt strategy and inject execution context
self.prompt_strategy = self._create_prompt_strategy(app_config)
if hasattr(self.prompt_strategy, "set_execution_context"):
self.prompt_strategy.set_execution_context(self.execution_context)
# Multi-step strategies have set_execution_context; one_shot doesn't
set_ctx = getattr(self.prompt_strategy, "set_execution_context", None)
if set_ctx is not None:
set_ctx(self.execution_context)
self.commands: list[Command] = []
@@ -263,12 +265,11 @@ class Agent(BaseAgent[AnyActionProposal], Configurable[AgentSettings]):
# Prepare messages (lazy compression) - skip if strategy will use cached actions
# ReWOO EXECUTING phase doesn't need messages, so skip compression
skip_message_prep = False
if hasattr(self.prompt_strategy, "current_phase"):
current_phase = getattr(self.prompt_strategy, "current_phase", None)
if current_phase is not None:
from .prompt_strategies.rewoo import ReWOOPhase
skip_message_prep = (
self.prompt_strategy.current_phase == ReWOOPhase.EXECUTING
)
skip_message_prep = current_phase == ReWOOPhase.EXECUTING
if not skip_message_prep and hasattr(self, "history"):
await self.history.prepare_messages()
@@ -324,21 +325,21 @@ class Agent(BaseAgent[AnyActionProposal], Configurable[AgentSettings]):
thinking_kwargs: dict[str, Any] = {}
if hasattr(self, "app_config") and self.app_config:
if self.app_config.thinking_budget_tokens:
thinking_kwargs[
"thinking_budget_tokens"
] = self.app_config.thinking_budget_tokens
thinking_kwargs["thinking_budget_tokens"] = (
self.app_config.thinking_budget_tokens
)
if self.app_config.reasoning_effort:
thinking_kwargs["reasoning_effort"] = self.app_config.reasoning_effort
response: ChatModelResponse[
AnyActionProposal
] = await self.llm_provider.create_chat_completion(
prompt.messages,
model_name=self.llm.name,
completion_parser=self.prompt_strategy.parse_response_content,
functions=prompt.functions,
prefill_response=prompt.prefill_response,
**thinking_kwargs,
response: ChatModelResponse[AnyActionProposal] = (
await self.llm_provider.create_chat_completion(
prompt.messages,
model_name=self.llm.name,
completion_parser=self.prompt_strategy.parse_response_content,
functions=prompt.functions,
prefill_response=prompt.prefill_response,
**thinking_kwargs,
)
)
result = response.parsed_result
@@ -395,17 +396,16 @@ class Agent(BaseAgent[AnyActionProposal], Configurable[AgentSettings]):
# Notify ReWOO strategy of execution result for variable tracking
# This allows ReWOO to record results and substitute variables in later steps
if hasattr(self.prompt_strategy, "record_execution_result") and hasattr(
self.prompt_strategy, "current_plan"
):
plan = getattr(self.prompt_strategy, "current_plan", None)
if plan and plan.current_step_index < len(plan.steps):
record_result = getattr(self.prompt_strategy, "record_execution_result", None)
plan = getattr(self.prompt_strategy, "current_plan", None)
if record_result is not None and plan is not None:
if plan.current_step_index < len(plan.steps):
step = plan.steps[plan.current_step_index]
error_msg = None
if isinstance(result, ActionErrorResult):
error_msg = getattr(result, "reason", None) or str(result)
result_str = str(getattr(result, "outputs", result))
self.prompt_strategy.record_execution_result(
record_result(
step.variable_name,
result_str,
error=error_msg,

View File

@@ -194,10 +194,17 @@ class ReWOOPromptConfiguration(BasePromptStrategyConfiguration):
"2. Specify the tool to use and its arguments\n"
"3. Assign a variable name (#E1, #E2, etc.) to store the result\n"
"4. Later steps can reference earlier results using variable names\n\n"
"Format each step as:\n"
"Format each step EXACTLY as:\n"
"Plan: [Your reasoning for this step]\n"
'#E[n] = tool_name(arg1="value1", arg2=#E[m])\n\n'
"After all steps, provide the response in the required JSON format."
'#E1 = tool_name(arg1="value1", arg2="value2")\n\n'
"Example plan:\n"
"Plan: First, I need to list the files to understand the structure.\n"
'#E1 = list_folder(folder=".")\n'
"Plan: Next, I will read the main file to understand its contents.\n"
'#E2 = read_file(filename="main.py")\n'
"Plan: Finally, I will write the solution to a new file.\n"
'#E3 = write_to_file(filename="solution.txt", contents="The answer is 42")\n\n'
"Now create your plan following this EXACT format."
)
# Paper-style planner instruction (uses bracket syntax like the original paper)
@@ -315,10 +322,16 @@ class ReWOOPromptStrategy(BaseMultiStepPromptStrategy):
the cached action from the plan should be used instead of
making an LLM call. This is the core ReWOO optimization.
"""
self.logger.info(
f"ReWOO build_prompt: current_phase={self.current_phase.value}"
)
# EXECUTING phase: use pre-planned actions without LLM calls
if self.current_phase == ReWOOPhase.EXECUTING:
cached_action = self._get_cached_action_proposal()
if cached_action:
# current_plan is guaranteed to be set in EXECUTING phase
assert self.current_plan is not None
self.logger.debug(
f"ReWOO EXECUTING: Using cached action "
f"(step {self.current_plan.current_step_index + 1} "
@@ -427,7 +440,11 @@ class ReWOOPromptStrategy(BaseMultiStepPromptStrategy):
is_synthesis: bool,
) -> tuple[str, str]:
"""Build the system prompt."""
response_fmt_instruction, response_prefill = self._response_format_instruction()
# During planning, we want plan text format, not tool calls
is_planning = not is_synthesis and self.current_phase == ReWOOPhase.PLANNING
response_fmt_instruction, response_prefill = self._response_format_instruction(
is_planning=is_planning
)
system_prompt_parts = (
self.generate_intro_prompt(ai_profile)
@@ -450,8 +467,26 @@ class ReWOOPromptStrategy(BaseMultiStepPromptStrategy):
response_prefill,
)
def _response_format_instruction(self) -> tuple[str, str]:
"""Generate response format instruction."""
def _response_format_instruction(
self, is_planning: bool = False
) -> tuple[str, str]:
"""Generate response format instruction.
Args:
is_planning: If True, we're in PLANNING phase and want plan text,
not tool calls.
"""
if is_planning:
# During planning, we want the plan in text format, not tool calls
return (
"Output your plan following the EXACT format specified in the "
"instructions. Each step must have:\n"
'- "Plan:" followed by your reasoning\n'
'- "#E[n] = tool_name(arg1=\\"value1\\", ...)" on the next line\n\n'
"Do NOT call any tools directly - just write out the plan steps.",
"", # No prefill for planning
)
schema = self._response_schema.model_copy(deep=True)
assert schema.properties
@@ -510,6 +545,61 @@ class ReWOOPromptStrategy(BaseMultiStepPromptStrategy):
)
)
# Parse plan from response FIRST if in planning phase.
# During PLANNING, we expect plan text format, not JSON
if self.current_phase == ReWOOPhase.PLANNING:
self.logger.info("ReWOO: Attempting to extract plan from PLANNING response")
plan = self._extract_plan_from_response(response.content)
if plan and plan.steps:
self.current_plan = plan
# Transition to EXECUTING phase now that we have a plan
self.current_phase = ReWOOPhase.EXECUTING
self.logger.info(
f"ReWOO: Extracted plan with {len(plan.steps)} steps, "
f"transitioning to EXECUTING phase"
)
# Use the first step of the plan as the action to execute
first_step = plan.steps[0]
first_action = AssistantFunctionCall(
name=first_step.tool_name,
arguments=first_step.tool_arguments,
)
# Build a complete proposal from the plan
thoughts = ReWOOThoughts(
observations="Created ReWOO execution plan",
reasoning=first_step.thought,
plan=[f"{s.variable_name}: {s.thought}" for s in plan.steps],
)
# Create synthetic raw message
from forge.llm.providers.schema import AssistantToolCall
raw_message = AssistantChatMessage(
content=response.content,
tool_calls=[
AssistantToolCall(
id="rewoo_plan_step_0",
type="function",
function=first_action,
)
],
)
return ReWOOActionProposal(
thoughts=thoughts,
use_tool=first_action,
raw_message=raw_message,
)
else:
self.logger.warning(
"ReWOO: Failed to extract plan from response, staying in PLANNING. "
f"Plan: {plan}, Steps: {plan.steps if plan else 'N/A'}"
)
# Fall through to standard JSON parsing if plan extraction fails
# For non-planning phases or if plan extraction failed, parse as JSON
assistant_reply_dict = extract_dict_from_json(response.content)
self.logger.debug(
"Parsing object extracted from LLM response:\n"
@@ -521,18 +611,6 @@ class ReWOOPromptStrategy(BaseMultiStepPromptStrategy):
assistant_reply_dict["use_tool"] = response.tool_calls[0].function
# Parse plan from response if in planning phase
if self.current_phase == ReWOOPhase.PLANNING:
plan = self._extract_plan_from_response(response.content)
if plan and plan.steps:
self.current_plan = plan
# Transition to EXECUTING phase now that we have a plan
self.current_phase = ReWOOPhase.EXECUTING
self.logger.info(
f"ReWOO: Extracted plan with {len(plan.steps)} steps, "
f"transitioning to EXECUTING phase"
)
# Ensure thoughts dict has required fields
thoughts_dict = assistant_reply_dict.get("thoughts", {})
if not isinstance(thoughts_dict, dict):
@@ -555,6 +633,7 @@ class ReWOOPromptStrategy(BaseMultiStepPromptStrategy):
1. Paper-style bracket format: #E1 = Tool[argument]
2. Function-style parenthesis format: #E1 = tool(arg1="value1")
"""
self.logger.debug(f"ReWOO: Extracting plan from content:\n{content[:1000]}...")
plan = ReWOOPlan()
# Pattern for paper-style bracket format: #E1 = Tool[argument]
@@ -572,11 +651,19 @@ class ReWOOPromptStrategy(BaseMultiStepPromptStrategy):
# Try bracket format first (paper-style)
matches = bracket_pattern.findall(content)
is_bracket_format = bool(matches)
self.logger.debug(f"ReWOO: Bracket pattern matches: {len(matches)}")
if not matches:
# Fall back to parenthesis format
matches = paren_pattern.findall(content)
is_bracket_format = False
self.logger.debug(f"ReWOO: Paren pattern matches: {len(matches)}")
if not matches:
self.logger.warning(
"ReWOO: No plan pattern matched. Expected format:\n"
'Plan: [reasoning]\n#E1 = tool_name(arg="value")'
)
for match in matches[: self.config.max_plan_steps]:
thought, var_num, tool_name, args_str = match

View File

@@ -18,7 +18,13 @@ from forge.models.config import Configurable, UserConfigurable
# Type alias for prompt strategy options
PromptStrategyName = Literal[
"one_shot", "rewoo", "plan_execute", "reflexion", "tree_of_thoughts"
"one_shot",
"rewoo",
"plan_execute",
"reflexion",
"tree_of_thoughts",
"lats",
"multi_agent_debate",
]
logger = logging.getLogger(__name__)

View File

@@ -160,7 +160,13 @@ skip_glob = ["data"]
[tool.pyright]
pythonVersion = "3.12"
exclude = ["data/**", "**/node_modules", "**/__pycache__", "**/.*"]
exclude = [
"data/**",
"**/node_modules",
"**/__pycache__",
"**/.*",
"direct_benchmark/challenges/**", # Legacy code with unavailable imports
]
[tool.pytest.ini_options]