mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
34 Commits
autogpt-v0
...
fixing-lin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0571eb33e6 | ||
|
|
d07dd42776 | ||
|
|
557894e6c1 | ||
|
|
94f0cfd38a | ||
|
|
08c32a7a12 | ||
|
|
56104bd047 | ||
|
|
2ef5cd7d4c | ||
|
|
74b3aae5c6 | ||
|
|
e9b3b5090c | ||
|
|
9bac6f4ce2 | ||
|
|
39c46ef6be | ||
|
|
78d83bb3ce | ||
|
|
d57ccf7ec9 | ||
|
|
ada2e19829 | ||
|
|
a7c7a5e18b | ||
|
|
180de0c9a9 | ||
|
|
8f0d5c73b3 | ||
|
|
3b00e8229c | ||
|
|
e97726cde3 | ||
|
|
d38e8b8f6c | ||
|
|
0014e2ac14 | ||
|
|
370615e5e4 | ||
|
|
f93c743d03 | ||
|
|
6add645597 | ||
|
|
bdda3a6698 | ||
|
|
126aacb2e3 | ||
|
|
1afc8e40df | ||
|
|
9543e5d6ac | ||
|
|
5e89b8c6d1 | ||
|
|
fd3f8fa5fc | ||
|
|
86bdbb82b1 | ||
|
|
898317c16c | ||
|
|
0704404344 | ||
|
|
a74548d3cd |
2
.github/workflows/autogpts-ci.yml
vendored
2
.github/workflows/autogpts-ci.yml
vendored
@@ -34,7 +34,7 @@ jobs:
|
||||
fail-fast: false
|
||||
timeout-minutes: 20
|
||||
env:
|
||||
min-python-version: '3.10'
|
||||
min-python-version: '3.12'
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
6
.pr_agent.toml
Normal file
6
.pr_agent.toml
Normal file
@@ -0,0 +1,6 @@
|
||||
[pr_reviewer]
|
||||
num_code_suggestions=0
|
||||
|
||||
[pr_code_suggestions]
|
||||
commitable_code_suggestions=false
|
||||
num_code_suggestions=0
|
||||
@@ -2,8 +2,11 @@
|
||||
### AutoGPT - GENERAL SETTINGS
|
||||
################################################################################
|
||||
|
||||
## OPENAI_API_KEY - OpenAI API Key (Example: my-openai-api-key)
|
||||
OPENAI_API_KEY=your-openai-api-key
|
||||
## OPENAI_API_KEY - OpenAI API Key (Example: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
|
||||
# OPENAI_API_KEY=
|
||||
|
||||
## ANTHROPIC_API_KEY - Anthropic API Key (Example: sk-ant-api03-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
|
||||
# ANTHROPIC_API_KEY=
|
||||
|
||||
## TELEMETRY_OPT_IN - Share telemetry on errors and other issues with the AutoGPT team, e.g. through Sentry.
|
||||
## This helps us to spot and solve problems earlier & faster. (Default: DISABLED)
|
||||
@@ -17,8 +20,8 @@ OPENAI_API_KEY=your-openai-api-key
|
||||
## RESTRICT_TO_WORKSPACE - Restrict file operations to workspace ./data/agents/<agent_id>/workspace (Default: True)
|
||||
# RESTRICT_TO_WORKSPACE=True
|
||||
|
||||
## DISABLED_COMMAND_CATEGORIES - The list of categories of commands that are disabled (Default: None)
|
||||
# DISABLED_COMMAND_CATEGORIES=
|
||||
## DISABLED_COMMANDS - The comma separated list of commands that are disabled (Default: None)
|
||||
# DISABLED_COMMANDS=
|
||||
|
||||
## FILE_STORAGE_BACKEND - Choose a storage backend for contents
|
||||
## Options: local, gcs, s3
|
||||
@@ -44,9 +47,6 @@ OPENAI_API_KEY=your-openai-api-key
|
||||
## AI_SETTINGS_FILE - Specifies which AI Settings file to use, relative to the AutoGPT root directory. (defaults to ai_settings.yaml)
|
||||
# AI_SETTINGS_FILE=ai_settings.yaml
|
||||
|
||||
## PLUGINS_CONFIG_FILE - The path to the plugins_config.yaml file, relative to the AutoGPT root directory. (Default plugins_config.yaml)
|
||||
# PLUGINS_CONFIG_FILE=plugins_config.yaml
|
||||
|
||||
## PROMPT_SETTINGS_FILE - Specifies which Prompt Settings file to use, relative to the AutoGPT root directory. (defaults to prompt_settings.yaml)
|
||||
# PROMPT_SETTINGS_FILE=prompt_settings.yaml
|
||||
|
||||
@@ -90,11 +90,11 @@ OPENAI_API_KEY=your-openai-api-key
|
||||
### LLM MODELS
|
||||
################################################################################
|
||||
|
||||
## SMART_LLM - Smart language model (Default: gpt-4-turbo-preview)
|
||||
# SMART_LLM=gpt-4-turbo-preview
|
||||
## SMART_LLM - Smart language model (Default: gpt-4-turbo)
|
||||
# SMART_LLM=gpt-4-turbo
|
||||
|
||||
## FAST_LLM - Fast language model (Default: gpt-3.5-turbo-0125)
|
||||
# FAST_LLM=gpt-3.5-turbo-0125
|
||||
## FAST_LLM - Fast language model (Default: gpt-3.5-turbo)
|
||||
# FAST_LLM=gpt-3.5-turbo
|
||||
|
||||
## EMBEDDING_MODEL - Model to use for creating embeddings
|
||||
# EMBEDDING_MODEL=text-embedding-3-small
|
||||
|
||||
@@ -22,6 +22,11 @@ repos:
|
||||
- id: black
|
||||
language_version: python3.10
|
||||
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 7.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
|
||||
# - repo: https://github.com/pre-commit/mirrors-mypy
|
||||
# rev: 'v1.3.0'
|
||||
# hooks:
|
||||
|
||||
@@ -5,12 +5,10 @@ from pathlib import Path
|
||||
|
||||
from autogpt.agent_manager.agent_manager import AgentManager
|
||||
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
|
||||
from autogpt.app.main import _configure_openai_provider, run_interaction_loop
|
||||
from autogpt.commands import COMMAND_CATEGORIES
|
||||
from autogpt.app.main import _configure_llm_provider, run_interaction_loop
|
||||
from autogpt.config import AIProfile, ConfigBuilder
|
||||
from autogpt.file_storage import FileStorageBackendName, get_storage
|
||||
from autogpt.logs.config import configure_logging
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
|
||||
LOG_DIR = Path(__file__).parent / "logs"
|
||||
|
||||
@@ -33,16 +31,12 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
|
||||
config.noninteractive_mode = True
|
||||
config.memory_backend = "no_memory"
|
||||
|
||||
command_registry = CommandRegistry.with_command_modules(COMMAND_CATEGORIES, config)
|
||||
|
||||
ai_profile = AIProfile(
|
||||
ai_name="AutoGPT",
|
||||
ai_role="a multi-purpose AI assistant.",
|
||||
ai_goals=[task],
|
||||
)
|
||||
|
||||
agent_prompt_config = Agent.default_settings.prompt_config.copy(deep=True)
|
||||
agent_prompt_config.use_functions_api = config.openai_functions
|
||||
agent_settings = AgentSettings(
|
||||
name=Agent.default_settings.name,
|
||||
agent_id=AgentManager.generate_id("AutoGPT-benchmark"),
|
||||
@@ -53,9 +47,7 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
|
||||
smart_llm=config.smart_llm,
|
||||
allow_fs_access=not config.restrict_to_workspace,
|
||||
use_functions_api=config.openai_functions,
|
||||
plugins=config.plugins,
|
||||
),
|
||||
prompt_config=agent_prompt_config,
|
||||
history=Agent.default_settings.history.copy(deep=True),
|
||||
)
|
||||
|
||||
@@ -68,8 +60,7 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
|
||||
|
||||
agent = Agent(
|
||||
settings=agent_settings,
|
||||
llm_provider=_configure_openai_provider(config),
|
||||
command_registry=command_registry,
|
||||
llm_provider=_configure_llm_provider(config),
|
||||
file_storage=file_storage,
|
||||
legacy_config=config,
|
||||
)
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
|
||||
from autogpt.commands import COMMAND_CATEGORIES
|
||||
from autogpt.config import AIDirectives, AIProfile, Config
|
||||
from autogpt.core.resource.model_providers import ChatModelProvider
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.logs.config import configure_chat_plugins
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
from autogpt.plugins import scan_plugins
|
||||
|
||||
|
||||
def create_agent(
|
||||
@@ -67,15 +63,6 @@ def _configure_agent(
|
||||
" must be specified"
|
||||
)
|
||||
|
||||
app_config.plugins = scan_plugins(app_config)
|
||||
configure_chat_plugins(app_config)
|
||||
|
||||
# Create a CommandRegistry instance and scan default folder
|
||||
command_registry = CommandRegistry.with_command_modules(
|
||||
modules=COMMAND_CATEGORIES,
|
||||
config=app_config,
|
||||
)
|
||||
|
||||
agent_state = state or create_agent_state(
|
||||
agent_id=agent_id,
|
||||
task=task,
|
||||
@@ -89,7 +76,6 @@ def _configure_agent(
|
||||
return Agent(
|
||||
settings=agent_state,
|
||||
llm_provider=llm_provider,
|
||||
command_registry=command_registry,
|
||||
file_storage=file_storage,
|
||||
legacy_config=app_config,
|
||||
)
|
||||
@@ -102,9 +88,6 @@ def create_agent_state(
|
||||
directives: AIDirectives,
|
||||
app_config: Config,
|
||||
) -> AgentSettings:
|
||||
agent_prompt_config = Agent.default_settings.prompt_config.copy(deep=True)
|
||||
agent_prompt_config.use_functions_api = app_config.openai_functions
|
||||
|
||||
return AgentSettings(
|
||||
agent_id=agent_id,
|
||||
name=Agent.default_settings.name,
|
||||
@@ -117,8 +100,6 @@ def create_agent_state(
|
||||
smart_llm=app_config.smart_llm,
|
||||
allow_fs_access=not app_config.restrict_to_workspace,
|
||||
use_functions_api=app_config.openai_functions,
|
||||
plugins=app_config.plugins,
|
||||
),
|
||||
prompt_config=agent_prompt_config,
|
||||
history=Agent.default_settings.history.copy(deep=True),
|
||||
)
|
||||
|
||||
@@ -20,9 +20,9 @@ class AgentManager:
|
||||
def list_agents(self) -> list[str]:
|
||||
"""Return all agent directories within storage."""
|
||||
agent_dirs: list[str] = []
|
||||
for dir in self.file_manager.list_folders():
|
||||
if self.file_manager.exists(dir / "state.json"):
|
||||
agent_dirs.append(dir.name)
|
||||
for file_path in self.file_manager.list_files():
|
||||
if len(file_path.parts) == 2 and file_path.name == "state.json":
|
||||
agent_dirs.append(file_path.parent.name)
|
||||
return agent_dirs
|
||||
|
||||
def get_agent_dir(self, agent_id: str) -> Path:
|
||||
|
||||
37
autogpts/autogpt/autogpt/agents/README.md
Normal file
37
autogpts/autogpt/autogpt/agents/README.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# 🤖 Agents
|
||||
|
||||
Agent is composed of [🧩 Components](./components.md) and responsible for executing pipelines and some additional logic. The base class for all agents is `BaseAgent`, it has the necessary logic to collect components and execute protocols.
|
||||
|
||||
## Important methods
|
||||
|
||||
`BaseAgent` provides two abstract methods needed for any agent to work properly:
|
||||
1. `propose_action`: This method is responsible for proposing an action based on the current state of the agent, it returns `ThoughtProcessOutput`.
|
||||
2. `execute`: This method is responsible for executing the proposed action, returns `ActionResult`.
|
||||
|
||||
## AutoGPT Agent
|
||||
|
||||
`Agent` is the main agent provided by AutoGPT. It's a subclass of `BaseAgent`. It has all the [Built-in Components](./built-in-components.md). `Agent` implements the essential abstract methods from `BaseAgent`: `propose_action` and `execute`.
|
||||
|
||||
## Building your own Agent
|
||||
|
||||
The easiest way to build your own agent is to extend the `Agent` class and add additional components. By doing this you can reuse the existing components and the default logic for executing [⚙️ Protocols](./protocols.md).
|
||||
|
||||
```py
|
||||
class MyComponent(AgentComponent):
|
||||
pass
|
||||
|
||||
class MyAgent(Agent):
|
||||
def __init__(
|
||||
self,
|
||||
settings: AgentSettings,
|
||||
llm_provider: ChatModelProvider,
|
||||
file_storage: FileStorage,
|
||||
legacy_config: Config,
|
||||
):
|
||||
# Call the parent constructor to bring in the default components
|
||||
super().__init__(settings, llm_provider, file_storage, legacy_config)
|
||||
# Add your custom component
|
||||
self.my_component = MyComponent()
|
||||
```
|
||||
|
||||
For more customization, you can override the `propose_action` and `execute` or even subclass `BaseAgent` directly. This way you can have full control over the agent's components and behavior. Have a look at the [implementation of Agent](https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpts/autogpt/autogpt/agents/agent.py) for more details.
|
||||
@@ -1,4 +1,9 @@
|
||||
from .agent import Agent
|
||||
from .base import AgentThoughts, BaseAgent, CommandArgs, CommandName
|
||||
from .agent import Agent, OneShotAgentActionProposal
|
||||
from .base import BaseAgent, BaseAgentActionProposal
|
||||
|
||||
__all__ = ["BaseAgent", "Agent", "CommandName", "CommandArgs", "AgentThoughts"]
|
||||
__all__ = [
|
||||
"BaseAgent",
|
||||
"Agent",
|
||||
"BaseAgentActionProposal",
|
||||
"OneShotAgentActionProposal",
|
||||
]
|
||||
|
||||
@@ -2,57 +2,70 @@ from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import sentry_sdk
|
||||
from pydantic import Field
|
||||
|
||||
from autogpt.commands.execute_code import CodeExecutorComponent
|
||||
from autogpt.commands.git_operations import GitOperationsComponent
|
||||
from autogpt.commands.image_gen import ImageGeneratorComponent
|
||||
from autogpt.commands.system import SystemComponent
|
||||
from autogpt.commands.user_interaction import UserInteractionComponent
|
||||
from autogpt.commands.web_search import WebSearchComponent
|
||||
from autogpt.commands.web_selenium import WebSeleniumComponent
|
||||
from autogpt.components.event_history import EventHistoryComponent
|
||||
from autogpt.core.configuration import Configurable
|
||||
from autogpt.core.prompting import ChatPrompt
|
||||
from autogpt.core.resource.model_providers import (
|
||||
AssistantChatMessage,
|
||||
AssistantFunctionCall,
|
||||
ChatMessage,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
)
|
||||
from autogpt.core.runner.client_lib.logging.helpers import dump_prompt
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.llm.providers.openai import function_specs_from_commands
|
||||
from autogpt.logs.log_cycle import (
|
||||
CURRENT_CONTEXT_FILE_NAME,
|
||||
NEXT_ACTION_FILE_NAME,
|
||||
USER_INPUT_FILE_NAME,
|
||||
LogCycleHandler,
|
||||
)
|
||||
from autogpt.logs.utils import fmt_kwargs
|
||||
from autogpt.models.action_history import (
|
||||
Action,
|
||||
ActionErrorResult,
|
||||
ActionInterruptedByHuman,
|
||||
ActionResult,
|
||||
ActionSuccessResult,
|
||||
EpisodicActionHistory,
|
||||
)
|
||||
from autogpt.models.command import CommandOutput
|
||||
from autogpt.models.context_item import ContextItem
|
||||
|
||||
from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
|
||||
from .features.agent_file_manager import AgentFileManagerMixin
|
||||
from .features.context import ContextMixin
|
||||
from .features.watchdog import WatchdogMixin
|
||||
from .prompt_strategies.one_shot import (
|
||||
OneShotAgentPromptConfiguration,
|
||||
OneShotAgentPromptStrategy,
|
||||
)
|
||||
from .utils.exceptions import (
|
||||
from autogpt.models.command import Command, CommandOutput
|
||||
from autogpt.utils.exceptions import (
|
||||
AgentException,
|
||||
AgentTerminated,
|
||||
CommandExecutionError,
|
||||
DuplicateOperationError,
|
||||
UnknownCommandError,
|
||||
)
|
||||
|
||||
from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
|
||||
from .features.agent_file_manager import FileManagerComponent
|
||||
from .features.context import ContextComponent
|
||||
from .features.watchdog import WatchdogComponent
|
||||
from .prompt_strategies.one_shot import (
|
||||
OneShotAgentActionProposal,
|
||||
OneShotAgentPromptStrategy,
|
||||
)
|
||||
from .protocols import (
|
||||
AfterExecute,
|
||||
AfterParse,
|
||||
CommandProvider,
|
||||
DirectiveProvider,
|
||||
MessageProvider,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -63,49 +76,64 @@ class AgentConfiguration(BaseAgentConfiguration):
|
||||
|
||||
class AgentSettings(BaseAgentSettings):
|
||||
config: AgentConfiguration = Field(default_factory=AgentConfiguration)
|
||||
prompt_config: OneShotAgentPromptConfiguration = Field(
|
||||
default_factory=(
|
||||
lambda: OneShotAgentPromptStrategy.default_configuration.copy(deep=True)
|
||||
)
|
||||
|
||||
history: EpisodicActionHistory[OneShotAgentActionProposal] = Field(
|
||||
default_factory=EpisodicActionHistory[OneShotAgentActionProposal]
|
||||
)
|
||||
"""(STATE) The action history of the agent."""
|
||||
|
||||
|
||||
class Agent(
|
||||
ContextMixin,
|
||||
AgentFileManagerMixin,
|
||||
WatchdogMixin,
|
||||
BaseAgent,
|
||||
Configurable[AgentSettings],
|
||||
):
|
||||
"""AutoGPT's primary Agent; uses one-shot prompting."""
|
||||
|
||||
class Agent(BaseAgent, Configurable[AgentSettings]):
|
||||
default_settings: AgentSettings = AgentSettings(
|
||||
name="Agent",
|
||||
description=__doc__,
|
||||
description=__doc__ if __doc__ else "",
|
||||
)
|
||||
|
||||
prompt_strategy: OneShotAgentPromptStrategy
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: AgentSettings,
|
||||
llm_provider: ChatModelProvider,
|
||||
command_registry: CommandRegistry,
|
||||
file_storage: FileStorage,
|
||||
legacy_config: Config,
|
||||
):
|
||||
prompt_strategy = OneShotAgentPromptStrategy(
|
||||
configuration=settings.prompt_config,
|
||||
logger=logger,
|
||||
super().__init__(settings)
|
||||
|
||||
self.llm_provider = llm_provider
|
||||
self.ai_profile = settings.ai_profile
|
||||
self.directives = settings.directives
|
||||
prompt_config = OneShotAgentPromptStrategy.default_configuration.copy(deep=True)
|
||||
prompt_config.use_functions_api = (
|
||||
settings.config.use_functions_api
|
||||
# Anthropic currently doesn't support tools + prefilling :(
|
||||
and self.llm.provider_name != "anthropic"
|
||||
)
|
||||
super().__init__(
|
||||
settings=settings,
|
||||
llm_provider=llm_provider,
|
||||
prompt_strategy=prompt_strategy,
|
||||
command_registry=command_registry,
|
||||
file_storage=file_storage,
|
||||
legacy_config=legacy_config,
|
||||
self.prompt_strategy = OneShotAgentPromptStrategy(prompt_config, logger)
|
||||
self.commands: list[Command] = []
|
||||
|
||||
# Components
|
||||
self.system = SystemComponent(legacy_config, settings.ai_profile)
|
||||
self.history = EventHistoryComponent(
|
||||
settings.history,
|
||||
self.send_token_limit,
|
||||
lambda x: self.llm_provider.count_tokens(x, self.llm.name),
|
||||
legacy_config,
|
||||
llm_provider,
|
||||
)
|
||||
self.user_interaction = UserInteractionComponent(legacy_config)
|
||||
self.file_manager = FileManagerComponent(settings, file_storage)
|
||||
self.code_executor = CodeExecutorComponent(
|
||||
self.file_manager.workspace,
|
||||
settings,
|
||||
legacy_config,
|
||||
)
|
||||
self.git_ops = GitOperationsComponent(legacy_config)
|
||||
self.image_gen = ImageGeneratorComponent(
|
||||
self.file_manager.workspace, legacy_config
|
||||
)
|
||||
self.web_search = WebSearchComponent(legacy_config)
|
||||
self.web_selenium = WebSeleniumComponent(legacy_config, llm_provider, self.llm)
|
||||
self.context = ContextComponent(self.file_manager.workspace)
|
||||
self.watchdog = WatchdogComponent(settings.config, settings.history)
|
||||
|
||||
self.created_at = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
"""Timestamp the agent was created; only used for structured debug logging."""
|
||||
@@ -113,186 +141,153 @@ class Agent(
|
||||
self.log_cycle_handler = LogCycleHandler()
|
||||
"""LogCycleHandler for structured debug logging."""
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
*args,
|
||||
extra_messages: Optional[list[ChatMessage]] = None,
|
||||
include_os_info: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> ChatPrompt:
|
||||
if not extra_messages:
|
||||
extra_messages = []
|
||||
self.event_history = settings.history
|
||||
self.legacy_config = legacy_config
|
||||
|
||||
# Clock
|
||||
extra_messages.append(
|
||||
ChatMessage.system(f"The current time and date is {time.strftime('%c')}"),
|
||||
async def propose_action(self) -> OneShotAgentActionProposal:
|
||||
"""Proposes the next action to execute, based on the task and current state.
|
||||
|
||||
Returns:
|
||||
The command name and arguments, if any, and the agent's thoughts.
|
||||
"""
|
||||
self.reset_trace()
|
||||
|
||||
# Get directives
|
||||
resources = await self.run_pipeline(DirectiveProvider.get_resources)
|
||||
constraints = await self.run_pipeline(DirectiveProvider.get_constraints)
|
||||
best_practices = await self.run_pipeline(DirectiveProvider.get_best_practices)
|
||||
|
||||
directives = self.state.directives.copy(deep=True)
|
||||
directives.resources += resources
|
||||
directives.constraints += constraints
|
||||
directives.best_practices += best_practices
|
||||
|
||||
# Get commands
|
||||
self.commands = await self.run_pipeline(CommandProvider.get_commands)
|
||||
self._remove_disabled_commands()
|
||||
|
||||
# Get messages
|
||||
messages = await self.run_pipeline(MessageProvider.get_messages)
|
||||
|
||||
prompt: ChatPrompt = self.prompt_strategy.build_prompt(
|
||||
messages=messages,
|
||||
task=self.state.task,
|
||||
ai_profile=self.state.ai_profile,
|
||||
ai_directives=directives,
|
||||
commands=function_specs_from_commands(self.commands),
|
||||
include_os_info=self.legacy_config.execute_local_commands,
|
||||
)
|
||||
|
||||
if include_os_info is None:
|
||||
include_os_info = self.legacy_config.execute_local_commands
|
||||
|
||||
return super().build_prompt(
|
||||
*args,
|
||||
extra_messages=extra_messages,
|
||||
include_os_info=include_os_info,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_before_think(self, *args, **kwargs) -> ChatPrompt:
|
||||
prompt = super().on_before_think(*args, **kwargs)
|
||||
|
||||
self.log_cycle_handler.log_count_within_cycle = 0
|
||||
self.log_cycle_handler.log_cycle(
|
||||
self.ai_profile.ai_name,
|
||||
self.state.ai_profile.ai_name,
|
||||
self.created_at,
|
||||
self.config.cycle_count,
|
||||
prompt.raw(),
|
||||
CURRENT_CONTEXT_FILE_NAME,
|
||||
)
|
||||
return prompt
|
||||
|
||||
def parse_and_process_response(
|
||||
self, llm_response: AssistantChatMessage, *args, **kwargs
|
||||
) -> Agent.ThoughtProcessOutput:
|
||||
for plugin in self.config.plugins:
|
||||
if not plugin.can_handle_post_planning():
|
||||
continue
|
||||
llm_response.content = plugin.post_planning(llm_response.content or "")
|
||||
logger.debug(f"Executing prompt:\n{dump_prompt(prompt)}")
|
||||
output = await self.complete_and_parse(prompt)
|
||||
self.config.cycle_count += 1
|
||||
|
||||
(
|
||||
command_name,
|
||||
arguments,
|
||||
assistant_reply_dict,
|
||||
) = self.prompt_strategy.parse_response_content(llm_response)
|
||||
return output
|
||||
|
||||
# Check if command_name and arguments are already in the event_history
|
||||
if self.event_history.matches_last_command(command_name, arguments):
|
||||
raise DuplicateOperationError(
|
||||
f"The command {command_name} with arguments {arguments} "
|
||||
f"has been just executed."
|
||||
)
|
||||
async def complete_and_parse(
|
||||
self, prompt: ChatPrompt, exception: Optional[Exception] = None
|
||||
) -> OneShotAgentActionProposal:
|
||||
if exception:
|
||||
prompt.messages.append(ChatMessage.system(f"Error: {exception}"))
|
||||
|
||||
response: ChatModelResponse[
|
||||
OneShotAgentActionProposal
|
||||
] = await self.llm_provider.create_chat_completion(
|
||||
prompt.messages,
|
||||
model_name=self.llm.name,
|
||||
completion_parser=self.prompt_strategy.parse_response_content,
|
||||
functions=prompt.functions,
|
||||
prefill_response=prompt.prefill_response,
|
||||
)
|
||||
result = response.parsed_result
|
||||
|
||||
self.log_cycle_handler.log_cycle(
|
||||
self.ai_profile.ai_name,
|
||||
self.state.ai_profile.ai_name,
|
||||
self.created_at,
|
||||
self.config.cycle_count,
|
||||
assistant_reply_dict,
|
||||
result.thoughts.dict(),
|
||||
NEXT_ACTION_FILE_NAME,
|
||||
)
|
||||
|
||||
if command_name:
|
||||
self.event_history.register_action(
|
||||
Action(
|
||||
name=command_name,
|
||||
args=arguments,
|
||||
reasoning=assistant_reply_dict["thoughts"]["reasoning"],
|
||||
)
|
||||
)
|
||||
|
||||
return command_name, arguments, assistant_reply_dict
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
command_name: str,
|
||||
command_args: dict[str, str] = {},
|
||||
user_input: str = "",
|
||||
) -> ActionResult:
|
||||
result: ActionResult
|
||||
|
||||
if command_name == "human_feedback":
|
||||
result = ActionInterruptedByHuman(feedback=user_input)
|
||||
self.log_cycle_handler.log_cycle(
|
||||
self.ai_profile.ai_name,
|
||||
self.created_at,
|
||||
self.config.cycle_count,
|
||||
user_input,
|
||||
USER_INPUT_FILE_NAME,
|
||||
)
|
||||
|
||||
else:
|
||||
for plugin in self.config.plugins:
|
||||
if not plugin.can_handle_pre_command():
|
||||
continue
|
||||
command_name, command_args = plugin.pre_command(
|
||||
command_name, command_args
|
||||
)
|
||||
|
||||
try:
|
||||
return_value = await execute_command(
|
||||
command_name=command_name,
|
||||
arguments=command_args,
|
||||
agent=self,
|
||||
)
|
||||
|
||||
# Intercept ContextItem if one is returned by the command
|
||||
if type(return_value) is tuple and isinstance(
|
||||
return_value[1], ContextItem
|
||||
):
|
||||
context_item = return_value[1]
|
||||
return_value = return_value[0]
|
||||
logger.debug(
|
||||
f"Command {command_name} returned a ContextItem: {context_item}"
|
||||
)
|
||||
self.context.add(context_item)
|
||||
|
||||
result = ActionSuccessResult(outputs=return_value)
|
||||
except AgentTerminated:
|
||||
raise
|
||||
except AgentException as e:
|
||||
result = ActionErrorResult.from_exception(e)
|
||||
logger.warning(
|
||||
f"{command_name}({fmt_kwargs(command_args)}) raised an error: {e}"
|
||||
)
|
||||
sentry_sdk.capture_exception(e)
|
||||
|
||||
result_tlength = self.llm_provider.count_tokens(str(result), self.llm.name)
|
||||
if result_tlength > self.send_token_limit // 3:
|
||||
result = ActionErrorResult(
|
||||
reason=f"Command {command_name} returned too much output. "
|
||||
"Do not execute this command again with the same arguments."
|
||||
)
|
||||
|
||||
for plugin in self.config.plugins:
|
||||
if not plugin.can_handle_post_command():
|
||||
continue
|
||||
if result.status == "success":
|
||||
result.outputs = plugin.post_command(command_name, result.outputs)
|
||||
elif result.status == "error":
|
||||
result.reason = plugin.post_command(command_name, result.reason)
|
||||
|
||||
# Update action history
|
||||
self.event_history.register_result(result)
|
||||
await self.event_history.handle_compression(
|
||||
self.llm_provider, self.legacy_config
|
||||
)
|
||||
await self.run_pipeline(AfterParse.after_parse, result)
|
||||
|
||||
return result
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
proposal: OneShotAgentActionProposal,
|
||||
user_feedback: str = "",
|
||||
) -> ActionResult:
|
||||
tool = proposal.use_tool
|
||||
|
||||
#############
|
||||
# Utilities #
|
||||
#############
|
||||
# Get commands
|
||||
self.commands = await self.run_pipeline(CommandProvider.get_commands)
|
||||
self._remove_disabled_commands()
|
||||
|
||||
|
||||
async def execute_command(
|
||||
command_name: str,
|
||||
arguments: dict[str, str],
|
||||
agent: Agent,
|
||||
) -> CommandOutput:
|
||||
"""Execute the command and return the result
|
||||
|
||||
Args:
|
||||
command_name (str): The name of the command to execute
|
||||
arguments (dict): The arguments for the command
|
||||
agent (Agent): The agent that is executing the command
|
||||
|
||||
Returns:
|
||||
str: The result of the command
|
||||
"""
|
||||
# Execute a native command with the same name or alias, if it exists
|
||||
if command := agent.command_registry.get_command(command_name):
|
||||
try:
|
||||
result = command(**arguments, agent=agent)
|
||||
return_value = await self._execute_tool(tool)
|
||||
|
||||
result = ActionSuccessResult(outputs=return_value)
|
||||
except AgentTerminated:
|
||||
raise
|
||||
except AgentException as e:
|
||||
result = ActionErrorResult.from_exception(e)
|
||||
logger.warning(f"{tool} raised an error: {e}")
|
||||
sentry_sdk.capture_exception(e)
|
||||
|
||||
result_tlength = self.llm_provider.count_tokens(str(result), self.llm.name)
|
||||
if result_tlength > self.send_token_limit // 3:
|
||||
result = ActionErrorResult(
|
||||
reason=f"Command {tool.name} returned too much output. "
|
||||
"Do not execute this command again with the same arguments."
|
||||
)
|
||||
|
||||
await self.run_pipeline(AfterExecute.after_execute, result)
|
||||
|
||||
logger.debug("\n".join(self.trace))
|
||||
|
||||
return result
|
||||
|
||||
async def do_not_execute(
|
||||
self, denied_proposal: OneShotAgentActionProposal, user_feedback: str
|
||||
) -> ActionResult:
|
||||
result = ActionInterruptedByHuman(feedback=user_feedback)
|
||||
self.log_cycle_handler.log_cycle(
|
||||
self.state.ai_profile.ai_name,
|
||||
self.created_at,
|
||||
self.config.cycle_count,
|
||||
user_feedback,
|
||||
USER_INPUT_FILE_NAME,
|
||||
)
|
||||
|
||||
await self.run_pipeline(AfterExecute.after_execute, result)
|
||||
|
||||
logger.debug("\n".join(self.trace))
|
||||
|
||||
return result
|
||||
|
||||
async def _execute_tool(self, tool_call: AssistantFunctionCall) -> CommandOutput:
|
||||
"""Execute the command and return the result
|
||||
|
||||
Args:
|
||||
tool_call (AssistantFunctionCall): The tool call to execute
|
||||
|
||||
Returns:
|
||||
str: The execution result
|
||||
"""
|
||||
# Execute a native command with the same name or alias, if it exists
|
||||
command = self._get_command(tool_call.name)
|
||||
try:
|
||||
result = command(**tool_call.arguments)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
@@ -301,20 +296,31 @@ async def execute_command(
|
||||
except Exception as e:
|
||||
raise CommandExecutionError(str(e))
|
||||
|
||||
# Handle non-native commands (e.g. from plugins)
|
||||
if agent._prompt_scratchpad:
|
||||
for name, command in agent._prompt_scratchpad.commands.items():
|
||||
if (
|
||||
command_name == name
|
||||
or command_name.lower() == command.description.lower()
|
||||
):
|
||||
try:
|
||||
return command.method(**arguments)
|
||||
except AgentException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise CommandExecutionError(str(e))
|
||||
def _get_command(self, command_name: str) -> Command:
|
||||
for command in reversed(self.commands):
|
||||
if command_name in command.names:
|
||||
return command
|
||||
|
||||
raise UnknownCommandError(
|
||||
f"Cannot execute command '{command_name}': unknown command."
|
||||
)
|
||||
raise UnknownCommandError(
|
||||
f"Cannot execute command '{command_name}': unknown command."
|
||||
)
|
||||
|
||||
def _remove_disabled_commands(self) -> None:
|
||||
self.commands = [
|
||||
command
|
||||
for command in self.commands
|
||||
if not any(
|
||||
name in self.legacy_config.disabled_commands for name in command.names
|
||||
)
|
||||
]
|
||||
|
||||
def find_obscured_commands(self) -> list[Command]:
|
||||
seen_names = set()
|
||||
obscured_commands = []
|
||||
for command in reversed(self.commands):
|
||||
# If all of the command's names have been seen, it's obscured
|
||||
if seen_names.issuperset(command.names):
|
||||
obscured_commands.append(command)
|
||||
else:
|
||||
seen_names.update(command.names)
|
||||
return list(reversed(obscured_commands))
|
||||
|
||||
@@ -1,24 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Iterator,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
from pydantic import Field, validator
|
||||
from colorama import Fore
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
from autogpt.core.prompting.base import PromptStrategy
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
AssistantChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
)
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
from autogpt.models.action_history import ActionResult
|
||||
|
||||
from autogpt.agents.utils.prompt_scratchpad import PromptScratchpad
|
||||
from autogpt.agents import protocols as _protocols
|
||||
from autogpt.agents.components import (
|
||||
AgentComponent,
|
||||
ComponentEndpointError,
|
||||
EndpointPipelineError,
|
||||
)
|
||||
from autogpt.config import ConfigBuilder
|
||||
from autogpt.config.ai_directives import AIDirectives
|
||||
from autogpt.config.ai_profile import AIProfile
|
||||
@@ -28,33 +39,26 @@ from autogpt.core.configuration import (
|
||||
SystemSettings,
|
||||
UserConfigurable,
|
||||
)
|
||||
from autogpt.core.prompting.schema import (
|
||||
ChatMessage,
|
||||
ChatPrompt,
|
||||
CompletionModelFunction,
|
||||
from autogpt.core.resource.model_providers import (
|
||||
CHAT_MODELS,
|
||||
AssistantFunctionCall,
|
||||
ModelName,
|
||||
)
|
||||
from autogpt.core.resource.model_providers.openai import (
|
||||
OPEN_AI_CHAT_MODELS,
|
||||
OpenAIModelName,
|
||||
)
|
||||
from autogpt.core.runner.client_lib.logging.helpers import dump_prompt
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.llm.providers.openai import get_openai_command_specs
|
||||
from autogpt.models.action_history import ActionResult, EpisodicActionHistory
|
||||
from autogpt.core.resource.model_providers.openai import OpenAIModelName
|
||||
from autogpt.models.utils import ModelWithSummary
|
||||
from autogpt.prompts.prompt import DEFAULT_TRIGGERING_PROMPT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CommandName = str
|
||||
CommandArgs = dict[str, str]
|
||||
AgentThoughts = dict[str, Any]
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class BaseAgentConfiguration(SystemConfiguration):
|
||||
allow_fs_access: bool = UserConfigurable(default=False)
|
||||
|
||||
fast_llm: OpenAIModelName = UserConfigurable(default=OpenAIModelName.GPT3_16k)
|
||||
smart_llm: OpenAIModelName = UserConfigurable(default=OpenAIModelName.GPT4)
|
||||
fast_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT3_16k)
|
||||
smart_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT4)
|
||||
use_functions_api: bool = UserConfigurable(default=False)
|
||||
|
||||
default_cycle_instruction: str = DEFAULT_TRIGGERING_PROMPT
|
||||
@@ -90,21 +94,6 @@ class BaseAgentConfiguration(SystemConfiguration):
|
||||
summary_max_tlength: Optional[int] = None
|
||||
# TODO: move to ActionHistoryConfiguration
|
||||
|
||||
plugins: list[AutoGPTPluginTemplate] = Field(default_factory=list, exclude=True)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True # Necessary for plugins
|
||||
|
||||
@validator("plugins", each_item=True)
|
||||
def validate_plugins(cls, p: AutoGPTPluginTemplate | Any):
|
||||
assert issubclass(
|
||||
p.__class__, AutoGPTPluginTemplate
|
||||
), f"{p} does not subclass AutoGPTPluginTemplate"
|
||||
assert (
|
||||
p.__class__.__name__ != "AutoGPTPluginTemplate"
|
||||
), f"Plugins must subclass AutoGPTPluginTemplate; {p} is a template instance"
|
||||
return p
|
||||
|
||||
@validator("use_functions_api")
|
||||
def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
|
||||
if v:
|
||||
@@ -141,51 +130,44 @@ class BaseAgentSettings(SystemSettings):
|
||||
config: BaseAgentConfiguration = Field(default_factory=BaseAgentConfiguration)
|
||||
"""The configuration for this BaseAgent subsystem instance."""
|
||||
|
||||
history: EpisodicActionHistory = Field(default_factory=EpisodicActionHistory)
|
||||
"""(STATE) The action history of the agent."""
|
||||
|
||||
class AgentMeta(ABCMeta):
|
||||
def __call__(cls, *args, **kwargs):
|
||||
# Create instance of the class (Agent or BaseAgent)
|
||||
instance = super().__call__(*args, **kwargs)
|
||||
# Automatically collect modules after the instance is created
|
||||
instance._collect_components()
|
||||
return instance
|
||||
|
||||
|
||||
class BaseAgent(Configurable[BaseAgentSettings], ABC):
|
||||
"""Base class for all AutoGPT agent classes."""
|
||||
class BaseAgentActionProposal(BaseModel):
|
||||
thoughts: str | ModelWithSummary
|
||||
use_tool: AssistantFunctionCall = None
|
||||
|
||||
ThoughtProcessOutput = tuple[CommandName, CommandArgs, AgentThoughts]
|
||||
|
||||
class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
|
||||
C = TypeVar("C", bound=AgentComponent)
|
||||
|
||||
default_settings = BaseAgentSettings(
|
||||
name="BaseAgent",
|
||||
description=__doc__,
|
||||
description=__doc__ if __doc__ else "",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: BaseAgentSettings,
|
||||
llm_provider: ChatModelProvider,
|
||||
prompt_strategy: PromptStrategy,
|
||||
command_registry: CommandRegistry,
|
||||
file_storage: FileStorage,
|
||||
legacy_config: Config,
|
||||
):
|
||||
self.state = settings
|
||||
self.components: list[AgentComponent] = []
|
||||
self.config = settings.config
|
||||
self.ai_profile = settings.ai_profile
|
||||
self.directives = settings.directives
|
||||
self.event_history = settings.history
|
||||
# Execution data for debugging
|
||||
self._trace: list[str] = []
|
||||
|
||||
self.legacy_config = legacy_config
|
||||
"""LEGACY: Monolithic application configuration."""
|
||||
logger.debug(f"Created {__class__} '{self.state.ai_profile.ai_name}'")
|
||||
|
||||
self.llm_provider = llm_provider
|
||||
|
||||
self.prompt_strategy = prompt_strategy
|
||||
|
||||
self.command_registry = command_registry
|
||||
"""The registry containing all commands available to the agent."""
|
||||
|
||||
self._prompt_scratchpad: PromptScratchpad | None = None
|
||||
|
||||
# Support multi-inheritance and mixins for subclasses
|
||||
super(BaseAgent, self).__init__()
|
||||
|
||||
logger.debug(f"Created {__class__} '{self.ai_profile.ai_name}'")
|
||||
@property
|
||||
def trace(self) -> list[str]:
|
||||
return self._trace
|
||||
|
||||
@property
|
||||
def llm(self) -> ChatModelInfo:
|
||||
@@ -193,204 +175,180 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
|
||||
llm_name = (
|
||||
self.config.smart_llm if self.config.big_brain else self.config.fast_llm
|
||||
)
|
||||
return OPEN_AI_CHAT_MODELS[llm_name]
|
||||
return CHAT_MODELS[llm_name]
|
||||
|
||||
@property
|
||||
def send_token_limit(self) -> int:
|
||||
return self.config.send_token_limit or self.llm.max_tokens * 3 // 4
|
||||
|
||||
async def propose_action(self) -> ThoughtProcessOutput:
|
||||
"""Proposes the next action to execute, based on the task and current state.
|
||||
|
||||
Returns:
|
||||
The command name and arguments, if any, and the agent's thoughts.
|
||||
"""
|
||||
|
||||
# Scratchpad as surrogate PromptGenerator for plugin hooks
|
||||
self._prompt_scratchpad = PromptScratchpad()
|
||||
|
||||
prompt: ChatPrompt = self.build_prompt(scratchpad=self._prompt_scratchpad)
|
||||
prompt = self.on_before_think(prompt, scratchpad=self._prompt_scratchpad)
|
||||
|
||||
logger.debug(f"Executing prompt:\n{dump_prompt(prompt)}")
|
||||
response = await self.llm_provider.create_chat_completion(
|
||||
prompt.messages,
|
||||
functions=get_openai_command_specs(
|
||||
self.command_registry.list_available_commands(self)
|
||||
)
|
||||
+ list(self._prompt_scratchpad.commands.values())
|
||||
if self.config.use_functions_api
|
||||
else [],
|
||||
model_name=self.llm.name,
|
||||
completion_parser=lambda r: self.parse_and_process_response(
|
||||
r,
|
||||
prompt,
|
||||
scratchpad=self._prompt_scratchpad,
|
||||
),
|
||||
)
|
||||
self.config.cycle_count += 1
|
||||
|
||||
return self.on_response(
|
||||
llm_response=response,
|
||||
prompt=prompt,
|
||||
scratchpad=self._prompt_scratchpad,
|
||||
)
|
||||
@abstractmethod
|
||||
async def propose_action(self) -> BaseAgentActionProposal:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def execute(
|
||||
self,
|
||||
command_name: str,
|
||||
command_args: dict[str, str] = {},
|
||||
user_input: str = "",
|
||||
proposal: BaseAgentActionProposal,
|
||||
user_feedback: str = "",
|
||||
) -> ActionResult:
|
||||
"""Executes the given command, if any, and returns the agent's response.
|
||||
|
||||
Params:
|
||||
command_name: The name of the command to execute, if any.
|
||||
command_args: The arguments to pass to the command, if any.
|
||||
user_input: The user's input, if any.
|
||||
|
||||
Returns:
|
||||
ActionResult: An object representing the result(s) of the command.
|
||||
"""
|
||||
...
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
scratchpad: PromptScratchpad,
|
||||
extra_commands: Optional[list[CompletionModelFunction]] = None,
|
||||
extra_messages: Optional[list[ChatMessage]] = None,
|
||||
**extras,
|
||||
) -> ChatPrompt:
|
||||
"""Constructs a prompt using `self.prompt_strategy`.
|
||||
|
||||
Params:
|
||||
scratchpad: An object for plugins to write additional prompt elements to.
|
||||
(E.g. commands, constraints, best practices)
|
||||
extra_commands: Additional commands that the agent has access to.
|
||||
extra_messages: Additional messages to include in the prompt.
|
||||
"""
|
||||
if not extra_commands:
|
||||
extra_commands = []
|
||||
if not extra_messages:
|
||||
extra_messages = []
|
||||
|
||||
# Apply additions from plugins
|
||||
for plugin in self.config.plugins:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
plugin.post_prompt(scratchpad)
|
||||
ai_directives = self.directives.copy(deep=True)
|
||||
ai_directives.resources += scratchpad.resources
|
||||
ai_directives.constraints += scratchpad.constraints
|
||||
ai_directives.best_practices += scratchpad.best_practices
|
||||
extra_commands += list(scratchpad.commands.values())
|
||||
|
||||
prompt = self.prompt_strategy.build_prompt(
|
||||
task=self.state.task,
|
||||
ai_profile=self.ai_profile,
|
||||
ai_directives=ai_directives,
|
||||
commands=get_openai_command_specs(
|
||||
self.command_registry.list_available_commands(self)
|
||||
)
|
||||
+ extra_commands,
|
||||
event_history=self.event_history,
|
||||
max_prompt_tokens=self.send_token_limit,
|
||||
count_tokens=lambda x: self.llm_provider.count_tokens(x, self.llm.name),
|
||||
count_message_tokens=lambda x: self.llm_provider.count_message_tokens(
|
||||
x, self.llm.name
|
||||
),
|
||||
extra_messages=extra_messages,
|
||||
**extras,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
def on_before_think(
|
||||
self,
|
||||
prompt: ChatPrompt,
|
||||
scratchpad: PromptScratchpad,
|
||||
) -> ChatPrompt:
|
||||
"""Called after constructing the prompt but before executing it.
|
||||
|
||||
Calls the `on_planning` hook of any enabled and capable plugins, adding their
|
||||
output to the prompt.
|
||||
|
||||
Params:
|
||||
prompt: The prompt that is about to be executed.
|
||||
scratchpad: An object for plugins to write additional prompt elements to.
|
||||
(E.g. commands, constraints, best practices)
|
||||
|
||||
Returns:
|
||||
The prompt to execute
|
||||
"""
|
||||
current_tokens_used = self.llm_provider.count_message_tokens(
|
||||
prompt.messages, self.llm.name
|
||||
)
|
||||
plugin_count = len(self.config.plugins)
|
||||
for i, plugin in enumerate(self.config.plugins):
|
||||
if not plugin.can_handle_on_planning():
|
||||
continue
|
||||
plugin_response = plugin.on_planning(scratchpad, prompt.raw())
|
||||
if not plugin_response or plugin_response == "":
|
||||
continue
|
||||
message_to_add = ChatMessage.system(plugin_response)
|
||||
tokens_to_add = self.llm_provider.count_message_tokens(
|
||||
message_to_add, self.llm.name
|
||||
)
|
||||
if current_tokens_used + tokens_to_add > self.send_token_limit:
|
||||
logger.debug(f"Plugin response too long, skipping: {plugin_response}")
|
||||
logger.debug(f"Plugins remaining at stop: {plugin_count - i}")
|
||||
break
|
||||
prompt.messages.insert(
|
||||
-1, message_to_add
|
||||
) # HACK: assumes cycle instruction to be at the end
|
||||
current_tokens_used += tokens_to_add
|
||||
return prompt
|
||||
|
||||
def on_response(
|
||||
self,
|
||||
llm_response: ChatModelResponse,
|
||||
prompt: ChatPrompt,
|
||||
scratchpad: PromptScratchpad,
|
||||
) -> ThoughtProcessOutput:
|
||||
"""Called upon receiving a response from the chat model.
|
||||
|
||||
Calls `self.parse_and_process_response()`.
|
||||
|
||||
Params:
|
||||
llm_response: The raw response from the chat model.
|
||||
prompt: The prompt that was executed.
|
||||
scratchpad: An object containing additional prompt elements from plugins.
|
||||
(E.g. commands, constraints, best practices)
|
||||
|
||||
Returns:
|
||||
The parsed command name and command args, if any, and the agent thoughts.
|
||||
"""
|
||||
|
||||
return llm_response.parsed_result
|
||||
|
||||
# TODO: update memory/context
|
||||
|
||||
@abstractmethod
|
||||
def parse_and_process_response(
|
||||
async def do_not_execute(
|
||||
self,
|
||||
llm_response: AssistantChatMessage,
|
||||
prompt: ChatPrompt,
|
||||
scratchpad: PromptScratchpad,
|
||||
) -> ThoughtProcessOutput:
|
||||
"""Validate, parse & process the LLM's response.
|
||||
denied_proposal: BaseAgentActionProposal,
|
||||
user_feedback: str,
|
||||
) -> ActionResult:
|
||||
...
|
||||
|
||||
Must be implemented by derivative classes: no base implementation is provided,
|
||||
since the implementation depends on the role of the derivative Agent.
|
||||
def reset_trace(self):
|
||||
self._trace = []
|
||||
|
||||
Params:
|
||||
llm_response: The raw response from the chat model.
|
||||
prompt: The prompt that was executed.
|
||||
scratchpad: An object containing additional prompt elements from plugins.
|
||||
(E.g. commands, constraints, best practices)
|
||||
@overload
|
||||
async def run_pipeline(
|
||||
self, protocol_method: Callable[P, Iterator[T]], *args, retry_limit: int = 3
|
||||
) -> list[T]:
|
||||
...
|
||||
|
||||
Returns:
|
||||
The parsed command name and command args, if any, and the agent thoughts.
|
||||
"""
|
||||
pass
|
||||
@overload
|
||||
async def run_pipeline(
|
||||
self, protocol_method: Callable[P, None], *args, retry_limit: int = 3
|
||||
) -> list[None]:
|
||||
...
|
||||
|
||||
async def run_pipeline(
|
||||
self,
|
||||
protocol_method: Callable[P, Iterator[T] | None],
|
||||
*args,
|
||||
retry_limit: int = 3,
|
||||
) -> list[T] | list[None]:
|
||||
method_name = protocol_method.__name__
|
||||
protocol_name = protocol_method.__qualname__.split(".")[0]
|
||||
protocol_class = getattr(_protocols, protocol_name)
|
||||
if not issubclass(protocol_class, AgentComponent):
|
||||
raise TypeError(f"{repr(protocol_method)} is not a protocol method")
|
||||
|
||||
# Clone parameters to revert on failure
|
||||
original_args = self._selective_copy(args)
|
||||
pipeline_attempts = 0
|
||||
method_result: list[T] = []
|
||||
self._trace.append(f"⬇️ {Fore.BLUE}{method_name}{Fore.RESET}")
|
||||
|
||||
while pipeline_attempts < retry_limit:
|
||||
try:
|
||||
for component in self.components:
|
||||
# Skip other protocols
|
||||
if not isinstance(component, protocol_class):
|
||||
continue
|
||||
|
||||
# Skip disabled components
|
||||
if not component.enabled:
|
||||
self._trace.append(
|
||||
f" {Fore.LIGHTBLACK_EX}"
|
||||
f"{component.__class__.__name__}{Fore.RESET}"
|
||||
)
|
||||
continue
|
||||
|
||||
method = getattr(component, method_name, None)
|
||||
if not callable(method):
|
||||
continue
|
||||
|
||||
component_attempts = 0
|
||||
while component_attempts < retry_limit:
|
||||
try:
|
||||
component_args = self._selective_copy(args)
|
||||
if inspect.iscoroutinefunction(method):
|
||||
result = await method(*component_args)
|
||||
else:
|
||||
result = method(*component_args)
|
||||
if result is not None:
|
||||
method_result.extend(result)
|
||||
args = component_args
|
||||
self._trace.append(f"✅ {component.__class__.__name__}")
|
||||
|
||||
except ComponentEndpointError:
|
||||
self._trace.append(
|
||||
f"❌ {Fore.YELLOW}{component.__class__.__name__}: "
|
||||
f"ComponentEndpointError{Fore.RESET}"
|
||||
)
|
||||
# Retry the same component on ComponentEndpointError
|
||||
component_attempts += 1
|
||||
continue
|
||||
# Successful component execution
|
||||
break
|
||||
# Successful pipeline execution
|
||||
break
|
||||
except EndpointPipelineError:
|
||||
self._trace.append(
|
||||
f"❌ {Fore.LIGHTRED_EX}{component.__class__.__name__}: "
|
||||
f"EndpointPipelineError{Fore.RESET}"
|
||||
)
|
||||
# Restart from the beginning on EndpointPipelineError
|
||||
# Revert to original parameters
|
||||
args = self._selective_copy(original_args)
|
||||
pipeline_attempts += 1
|
||||
continue # Start the loop over
|
||||
except Exception as e:
|
||||
raise e
|
||||
return method_result
|
||||
|
||||
def _collect_components(self):
|
||||
components = [
|
||||
getattr(self, attr)
|
||||
for attr in dir(self)
|
||||
if isinstance(getattr(self, attr), AgentComponent)
|
||||
]
|
||||
|
||||
if self.components:
|
||||
# Check if any coponent is missed (added to Agent but not to components)
|
||||
for component in components:
|
||||
if component not in self.components:
|
||||
logger.warning(
|
||||
f"Component {component.__class__.__name__} "
|
||||
"is attached to an agent but not added to components list"
|
||||
)
|
||||
# Skip collecting anf sorting and sort if ordering is explicit
|
||||
return
|
||||
self.components = self._topological_sort(components)
|
||||
|
||||
def _topological_sort(
|
||||
self, components: list[AgentComponent]
|
||||
) -> list[AgentComponent]:
|
||||
visited = set()
|
||||
stack = []
|
||||
|
||||
def visit(node: AgentComponent):
|
||||
if node in visited:
|
||||
return
|
||||
visited.add(node)
|
||||
for neighbor_class in node.__class__.run_after:
|
||||
# Find the instance of neighbor_class in components
|
||||
neighbor = next(
|
||||
(m for m in components if isinstance(m, neighbor_class)), None
|
||||
)
|
||||
if neighbor:
|
||||
visit(neighbor)
|
||||
stack.append(node)
|
||||
|
||||
for component in components:
|
||||
visit(component)
|
||||
|
||||
return stack
|
||||
|
||||
def _selective_copy(self, args: tuple[Any, ...]) -> tuple[Any, ...]:
|
||||
copied_args = []
|
||||
for item in args:
|
||||
if isinstance(item, list):
|
||||
# Shallow copy for lists
|
||||
copied_item = item[:]
|
||||
elif isinstance(item, dict):
|
||||
# Shallow copy for dicts
|
||||
copied_item = item.copy()
|
||||
elif isinstance(item, BaseModel):
|
||||
# Deep copy for Pydantic models (deep=True to also copy nested models)
|
||||
copied_item = item.copy(deep=True)
|
||||
else:
|
||||
# Deep copy for other objects
|
||||
copied_item = copy.deepcopy(item)
|
||||
copied_args.append(copied_item)
|
||||
return tuple(copied_args)
|
||||
|
||||
35
autogpts/autogpt/autogpt/agents/components.py
Normal file
35
autogpts/autogpt/autogpt/agents/components.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from abc import ABC
|
||||
from typing import Callable
|
||||
|
||||
|
||||
class AgentComponent(ABC):
|
||||
run_after: list[type["AgentComponent"]] = []
|
||||
_enabled: Callable[[], bool] | bool = True
|
||||
_disabled_reason: str = ""
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
if callable(self._enabled):
|
||||
return self._enabled()
|
||||
return self._enabled
|
||||
|
||||
@property
|
||||
def disabled_reason(self) -> str:
|
||||
return self._disabled_reason
|
||||
|
||||
|
||||
class ComponentEndpointError(Exception):
|
||||
"""Error of a single protocol method on a component."""
|
||||
|
||||
def __init__(self, message: str = ""):
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class EndpointPipelineError(ComponentEndpointError):
|
||||
"""Error of an entire pipline of one endpoint."""
|
||||
|
||||
|
||||
class ComponentSystemError(EndpointPipelineError):
|
||||
"""Error of a group of pipelines;
|
||||
multiple different enpoints."""
|
||||
@@ -1,18 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from autogpt.agents.protocols import CommandProvider, DirectiveProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.utils.file_operations_utils import decode_textual_file
|
||||
|
||||
from ..base import BaseAgent, BaseAgentSettings
|
||||
from ..base import BaseAgentSettings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentFileManagerMixin:
|
||||
"""Mixin that adds file manager (e.g. Agent state)
|
||||
and workspace manager (e.g. Agent output files) support."""
|
||||
class FileManagerComponent(DirectiveProvider, CommandProvider):
|
||||
"""
|
||||
Adds general file manager (e.g. Agent state),
|
||||
workspace manager (e.g. Agent output files) support and
|
||||
commands to perform operations on files and folders.
|
||||
"""
|
||||
|
||||
files: FileStorage
|
||||
"""Agent-related files, e.g. state, logs.
|
||||
@@ -25,49 +33,17 @@ class AgentFileManagerMixin:
|
||||
STATE_FILE = "state.json"
|
||||
"""The name of the file where the agent's state is stored."""
|
||||
|
||||
LOGS_FILE = "file_logger.log"
|
||||
"""The name of the file where the agent's logs are stored."""
|
||||
def __init__(self, state: BaseAgentSettings, file_storage: FileStorage):
|
||||
self.state = state
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Initialize other bases first, because we need the config from BaseAgent
|
||||
super(AgentFileManagerMixin, self).__init__(**kwargs)
|
||||
|
||||
if not isinstance(self, BaseAgent):
|
||||
raise NotImplementedError(
|
||||
f"{__class__.__name__} can only be applied to BaseAgent derivatives"
|
||||
)
|
||||
|
||||
if "file_storage" not in kwargs:
|
||||
raise ValueError(
|
||||
"AgentFileManagerMixin requires a file_storage in the constructor."
|
||||
)
|
||||
|
||||
state: BaseAgentSettings = getattr(self, "state")
|
||||
if not state.agent_id:
|
||||
raise ValueError("Agent must have an ID.")
|
||||
|
||||
file_storage: FileStorage = kwargs["file_storage"]
|
||||
self.files = file_storage.clone_with_subroot(f"agents/{state.agent_id}/")
|
||||
self.workspace = file_storage.clone_with_subroot(
|
||||
f"agents/{state.agent_id}/workspace"
|
||||
)
|
||||
self._file_storage = file_storage
|
||||
# Read and cache logs
|
||||
self._file_logs_cache = []
|
||||
if self.files.exists(self.LOGS_FILE):
|
||||
self._file_logs_cache = self.files.read_file(self.LOGS_FILE).split("\n")
|
||||
|
||||
async def log_file_operation(self, content: str) -> None:
|
||||
"""Log a file operation to the agent's log file."""
|
||||
logger.debug(f"Logging operation: {content}")
|
||||
self._file_logs_cache.append(content)
|
||||
await self.files.write_file(
|
||||
self.LOGS_FILE, "\n".join(self._file_logs_cache) + "\n"
|
||||
)
|
||||
|
||||
def get_file_operation_lines(self) -> list[str]:
|
||||
"""Get the agent's file operation logs as list of strings."""
|
||||
return self._file_logs_cache
|
||||
|
||||
async def save_state(self, save_as: Optional[str] = None) -> None:
|
||||
"""Save the agent's state to the state file."""
|
||||
@@ -100,3 +76,86 @@ class AgentFileManagerMixin:
|
||||
f"agents/{new_id}/workspace"
|
||||
)
|
||||
state.agent_id = new_id
|
||||
|
||||
def get_resources(self) -> Iterator[str]:
|
||||
yield "The ability to read and write files."
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.read_file
|
||||
yield self.write_to_file
|
||||
yield self.list_folder
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
"filename": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path of the file to read",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def read_file(self, filename: str | Path) -> str:
|
||||
"""Read a file and return the contents
|
||||
|
||||
Args:
|
||||
filename (str): The name of the file to read
|
||||
|
||||
Returns:
|
||||
str: The contents of the file
|
||||
"""
|
||||
file = self.workspace.open_file(filename, binary=True)
|
||||
content = decode_textual_file(file, os.path.splitext(filename)[1], logger)
|
||||
|
||||
return content
|
||||
|
||||
@command(
|
||||
["write_file", "create_file"],
|
||||
"Write a file, creating it if necessary. "
|
||||
"If the file exists, it is overwritten.",
|
||||
{
|
||||
"filename": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The name of the file to write to",
|
||||
required=True,
|
||||
),
|
||||
"contents": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The contents to write to the file",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
async def write_to_file(self, filename: str | Path, contents: str) -> str:
|
||||
"""Write contents to a file
|
||||
|
||||
Args:
|
||||
filename (str): The name of the file to write to
|
||||
contents (str): The contents to write to the file
|
||||
|
||||
Returns:
|
||||
str: A message indicating success or failure
|
||||
"""
|
||||
if directory := os.path.dirname(filename):
|
||||
self.workspace.make_dir(directory)
|
||||
await self.workspace.write_file(filename, contents)
|
||||
return f"File {filename} has been written successfully."
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
"folder": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The folder to list files in",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def list_folder(self, folder: str | Path) -> list[str]:
|
||||
"""Lists files in a folder recursively
|
||||
|
||||
Args:
|
||||
folder (str): The folder to search in
|
||||
|
||||
Returns:
|
||||
list[str]: A list of files found in the folder
|
||||
"""
|
||||
return [str(p) for p in self.workspace.list_files(folder)]
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.core.prompting import ChatPrompt
|
||||
from autogpt.models.context_item import ContextItem
|
||||
|
||||
from ..base import BaseAgent
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from autogpt.agents.protocols import CommandProvider, MessageProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.core.resource.model_providers import ChatMessage
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.models.context_item import ContextItem, FileContextItem, FolderContextItem
|
||||
from autogpt.utils.exceptions import InvalidArgumentError
|
||||
|
||||
|
||||
class AgentContext:
|
||||
@@ -32,51 +33,129 @@ class AgentContext:
|
||||
def clear(self) -> None:
|
||||
self.items.clear()
|
||||
|
||||
def format_numbered(self) -> str:
|
||||
return "\n\n".join([f"{i}. {c.fmt()}" for i, c in enumerate(self.items, 1)])
|
||||
def format_numbered(self, workspace: FileStorage) -> str:
|
||||
return "\n\n".join(
|
||||
[f"{i}. {c.fmt(workspace)}" for i, c in enumerate(self.items, 1)]
|
||||
)
|
||||
|
||||
|
||||
class ContextMixin:
|
||||
"""Mixin that adds context support to a BaseAgent subclass"""
|
||||
class ContextComponent(MessageProvider, CommandProvider):
|
||||
"""Adds ability to keep files and folders open in the context (prompt)."""
|
||||
|
||||
context: AgentContext
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
def __init__(self, workspace: FileStorage):
|
||||
self.context = AgentContext()
|
||||
self.workspace = workspace
|
||||
|
||||
super(ContextMixin, self).__init__(**kwargs)
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
*args: Any,
|
||||
extra_messages: Optional[list[ChatMessage]] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatPrompt:
|
||||
if not extra_messages:
|
||||
extra_messages = []
|
||||
|
||||
# Add context section to prompt
|
||||
def get_messages(self) -> Iterator[ChatMessage]:
|
||||
if self.context:
|
||||
extra_messages.insert(
|
||||
0,
|
||||
ChatMessage.system(
|
||||
"## Context\n"
|
||||
f"{self.context.format_numbered()}\n\n"
|
||||
"When a context item is no longer needed and you are not done yet, "
|
||||
"you can hide the item by specifying its number in the list above "
|
||||
"to `hide_context_item`.",
|
||||
),
|
||||
yield ChatMessage.system(
|
||||
"## Context\n"
|
||||
f"{self.context.format_numbered(self.workspace)}\n\n"
|
||||
"When a context item is no longer needed and you are not done yet, "
|
||||
"you can hide the item by specifying its number in the list above "
|
||||
"to `hide_context_item`.",
|
||||
)
|
||||
|
||||
return super(ContextMixin, self).build_prompt(
|
||||
*args,
|
||||
extra_messages=extra_messages,
|
||||
**kwargs,
|
||||
) # type: ignore
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.open_file
|
||||
yield self.open_folder
|
||||
if self.context:
|
||||
yield self.close_context_item
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
"file_path": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path of the file to open",
|
||||
required=True,
|
||||
)
|
||||
}
|
||||
)
|
||||
async def open_file(self, file_path: str | Path) -> str:
|
||||
"""Opens a file for editing or continued viewing;
|
||||
creates it if it does not exist yet.
|
||||
Note: If you only need to read or write a file once,
|
||||
use `write_to_file` instead.
|
||||
|
||||
def get_agent_context(agent: BaseAgent) -> AgentContext | None:
|
||||
if isinstance(agent, ContextMixin):
|
||||
return agent.context
|
||||
Args:
|
||||
file_path (str | Path): The path of the file to open
|
||||
|
||||
return None
|
||||
Returns:
|
||||
str: A status message indicating what happened
|
||||
"""
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
|
||||
created = False
|
||||
if not self.workspace.exists(file_path):
|
||||
await self.workspace.write_file(file_path, "")
|
||||
created = True
|
||||
|
||||
# Try to make the file path relative
|
||||
with contextlib.suppress(ValueError):
|
||||
file_path = file_path.relative_to(self.workspace.root)
|
||||
|
||||
file = FileContextItem(path=file_path)
|
||||
self.context.add(file)
|
||||
return (
|
||||
f"File {file_path}{' created,' if created else ''} has been opened"
|
||||
" and added to the context ✅"
|
||||
)
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
"path": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path of the folder to open",
|
||||
required=True,
|
||||
)
|
||||
}
|
||||
)
|
||||
def open_folder(self, path: str | Path) -> str:
|
||||
"""Open a folder to keep track of its content
|
||||
|
||||
Args:
|
||||
path (str | Path): The path of the folder to open
|
||||
|
||||
Returns:
|
||||
str: A status message indicating what happened
|
||||
"""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
|
||||
if not self.workspace.exists(path):
|
||||
raise FileNotFoundError(
|
||||
f"open_folder {path} failed: no such file or directory"
|
||||
)
|
||||
|
||||
# Try to make the path relative
|
||||
with contextlib.suppress(ValueError):
|
||||
path = path.relative_to(self.workspace.root)
|
||||
|
||||
folder = FolderContextItem(path=path)
|
||||
self.context.add(folder)
|
||||
return f"Folder {path} has been opened and added to the context ✅"
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
"number": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="The 1-based index of the context item to hide",
|
||||
required=True,
|
||||
)
|
||||
}
|
||||
)
|
||||
def close_context_item(self, number: int) -> str:
|
||||
"""Hide an open file, folder or other context item, to save tokens.
|
||||
|
||||
Args:
|
||||
number (int): The 1-based index of the context item to hide
|
||||
|
||||
Returns:
|
||||
str: A status message indicating what happened
|
||||
"""
|
||||
if number > len(self.context.items) or number == 0:
|
||||
raise InvalidArgumentError(f"Index {number} out of range")
|
||||
|
||||
self.context.close(number)
|
||||
return f"Context item {number} hidden ✅"
|
||||
|
||||
@@ -1,41 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import ExitStack
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..base import BaseAgentConfiguration
|
||||
|
||||
from autogpt.agents.base import BaseAgentActionProposal, BaseAgentConfiguration
|
||||
from autogpt.agents.components import ComponentSystemError
|
||||
from autogpt.agents.features.context import ContextComponent
|
||||
from autogpt.agents.protocols import AfterParse
|
||||
from autogpt.models.action_history import EpisodicActionHistory
|
||||
|
||||
from ..base import BaseAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WatchdogMixin:
|
||||
class WatchdogComponent(AfterParse):
|
||||
"""
|
||||
Mixin that adds a watchdog feature to an agent class. Whenever the agent starts
|
||||
Adds a watchdog feature to an agent class. Whenever the agent starts
|
||||
looping, the watchdog will switch from the FAST_LLM to the SMART_LLM and re-think.
|
||||
"""
|
||||
|
||||
config: BaseAgentConfiguration
|
||||
event_history: EpisodicActionHistory
|
||||
run_after = [ContextComponent]
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
# Initialize other bases first, because we need the event_history from BaseAgent
|
||||
super(WatchdogMixin, self).__init__(**kwargs)
|
||||
def __init__(
|
||||
self,
|
||||
config: BaseAgentConfiguration,
|
||||
event_history: EpisodicActionHistory[BaseAgentActionProposal],
|
||||
):
|
||||
self.config = config
|
||||
self.event_history = event_history
|
||||
self.revert_big_brain = False
|
||||
|
||||
if not isinstance(self, BaseAgent):
|
||||
raise NotImplementedError(
|
||||
f"{__class__.__name__} can only be applied to BaseAgent derivatives"
|
||||
)
|
||||
|
||||
async def propose_action(self, *args, **kwargs) -> BaseAgent.ThoughtProcessOutput:
|
||||
command_name, command_args, thoughts = await super(
|
||||
WatchdogMixin, self
|
||||
).propose_action(*args, **kwargs)
|
||||
def after_parse(self, result: BaseAgentActionProposal) -> None:
|
||||
if self.revert_big_brain:
|
||||
self.config.big_brain = False
|
||||
self.revert_big_brain = False
|
||||
|
||||
if not self.config.big_brain and self.config.fast_llm != self.config.smart_llm:
|
||||
previous_command, previous_command_args = None, None
|
||||
@@ -44,33 +38,23 @@ class WatchdogMixin:
|
||||
previous_cycle = self.event_history.episodes[
|
||||
self.event_history.cursor - 1
|
||||
]
|
||||
previous_command = previous_cycle.action.name
|
||||
previous_command_args = previous_cycle.action.args
|
||||
previous_command = previous_cycle.action.use_tool.name
|
||||
previous_command_args = previous_cycle.action.use_tool.arguments
|
||||
|
||||
rethink_reason = ""
|
||||
|
||||
if not command_name:
|
||||
if not result.use_tool:
|
||||
rethink_reason = "AI did not specify a command"
|
||||
elif (
|
||||
command_name == previous_command
|
||||
and command_args == previous_command_args
|
||||
result.use_tool.name == previous_command
|
||||
and result.use_tool.arguments == previous_command_args
|
||||
):
|
||||
rethink_reason = f"Repititive command detected ({command_name})"
|
||||
rethink_reason = f"Repititive command detected ({result.use_tool.name})"
|
||||
|
||||
if rethink_reason:
|
||||
logger.info(f"{rethink_reason}, re-thinking with SMART_LLM...")
|
||||
with ExitStack() as stack:
|
||||
|
||||
@stack.callback
|
||||
def restore_state() -> None:
|
||||
# Executed after exiting the ExitStack context
|
||||
self.config.big_brain = False
|
||||
|
||||
# Remove partial record of current cycle
|
||||
self.event_history.rewind()
|
||||
|
||||
# Switch to SMART_LLM and re-think
|
||||
self.big_brain = True
|
||||
return await self.propose_action(*args, **kwargs)
|
||||
|
||||
return command_name, command_args, thoughts
|
||||
self.event_history.rewind()
|
||||
self.big_brain = True
|
||||
self.revert_big_brain = True
|
||||
# Trigger retry of all pipelines prior to this component
|
||||
raise ComponentSystemError()
|
||||
|
||||
@@ -4,15 +4,11 @@ import json
|
||||
import platform
|
||||
import re
|
||||
from logging import Logger
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
import distro
|
||||
from pydantic import Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.models.action_history import Episode
|
||||
|
||||
from autogpt.agents.utils.exceptions import InvalidAgentResponseError
|
||||
from autogpt.agents.base import BaseAgentActionProposal
|
||||
from autogpt.config import AIDirectives, AIProfile
|
||||
from autogpt.core.configuration.schema import SystemConfiguration, UserConfigurable
|
||||
from autogpt.core.prompting import (
|
||||
@@ -27,7 +23,31 @@ from autogpt.core.resource.model_providers.schema import (
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.core.utils.json_utils import extract_dict_from_json
|
||||
from autogpt.prompts.utils import format_numbered_list, indent
|
||||
from autogpt.models.utils import ModelWithSummary
|
||||
from autogpt.prompts.utils import format_numbered_list
|
||||
from autogpt.utils.exceptions import InvalidAgentResponseError
|
||||
|
||||
_RESPONSE_INTERFACE_NAME = "AssistantResponse"
|
||||
|
||||
|
||||
class AssistantThoughts(ModelWithSummary):
|
||||
observations: str = Field(
|
||||
..., description="Relevant observations from your last action (if any)"
|
||||
)
|
||||
text: str = Field(..., description="Thoughts")
|
||||
reasoning: str = Field(..., description="Reasoning behind the thoughts")
|
||||
self_criticism: str = Field(..., description="Constructive self-criticism")
|
||||
plan: list[str] = Field(
|
||||
..., description="Short list that conveys the long-term plan"
|
||||
)
|
||||
speak: str = Field(..., description="Summary of thoughts, to say to user")
|
||||
|
||||
def summary(self) -> str:
|
||||
return self.text
|
||||
|
||||
|
||||
class OneShotAgentActionProposal(BaseAgentActionProposal):
|
||||
thoughts: AssistantThoughts
|
||||
|
||||
|
||||
class OneShotAgentPromptConfiguration(SystemConfiguration):
|
||||
@@ -55,70 +75,7 @@ class OneShotAgentPromptConfiguration(SystemConfiguration):
|
||||
"and respond using the JSON schema specified previously:"
|
||||
)
|
||||
|
||||
DEFAULT_RESPONSE_SCHEMA = JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
properties={
|
||||
"thoughts": JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
required=True,
|
||||
properties={
|
||||
"observations": JSONSchema(
|
||||
description=(
|
||||
"Relevant observations from your last action (if any)"
|
||||
),
|
||||
type=JSONSchema.Type.STRING,
|
||||
required=False,
|
||||
),
|
||||
"text": JSONSchema(
|
||||
description="Thoughts",
|
||||
type=JSONSchema.Type.STRING,
|
||||
required=True,
|
||||
),
|
||||
"reasoning": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
required=True,
|
||||
),
|
||||
"self_criticism": JSONSchema(
|
||||
description="Constructive self-criticism",
|
||||
type=JSONSchema.Type.STRING,
|
||||
required=True,
|
||||
),
|
||||
"plan": JSONSchema(
|
||||
description=(
|
||||
"Short markdown-style bullet list that conveys the "
|
||||
"long-term plan"
|
||||
),
|
||||
type=JSONSchema.Type.STRING,
|
||||
required=True,
|
||||
),
|
||||
"speak": JSONSchema(
|
||||
description="Summary of thoughts, to say to user",
|
||||
type=JSONSchema.Type.STRING,
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
"command": JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
required=True,
|
||||
properties={
|
||||
"name": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
required=True,
|
||||
),
|
||||
"args": JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
body_template: str = UserConfigurable(default=DEFAULT_BODY_TEMPLATE)
|
||||
response_schema: dict = UserConfigurable(
|
||||
default_factory=DEFAULT_RESPONSE_SCHEMA.to_dict
|
||||
)
|
||||
choose_action_instruction: str = UserConfigurable(
|
||||
default=DEFAULT_CHOOSE_ACTION_INSTRUCTION
|
||||
)
|
||||
@@ -143,7 +100,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
logger: Logger,
|
||||
):
|
||||
self.config = configuration
|
||||
self.response_schema = JSONSchema.from_dict(configuration.response_schema)
|
||||
self.response_schema = JSONSchema.from_dict(OneShotAgentActionProposal.schema())
|
||||
self.logger = logger
|
||||
|
||||
@property
|
||||
@@ -153,81 +110,55 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
def build_prompt(
|
||||
self,
|
||||
*,
|
||||
messages: list[ChatMessage],
|
||||
task: str,
|
||||
ai_profile: AIProfile,
|
||||
ai_directives: AIDirectives,
|
||||
commands: list[CompletionModelFunction],
|
||||
event_history: list[Episode],
|
||||
include_os_info: bool,
|
||||
max_prompt_tokens: int,
|
||||
count_tokens: Callable[[str], int],
|
||||
count_message_tokens: Callable[[ChatMessage | list[ChatMessage]], int],
|
||||
extra_messages: Optional[list[ChatMessage]] = None,
|
||||
**extras,
|
||||
) -> ChatPrompt:
|
||||
"""Constructs and returns a prompt with the following structure:
|
||||
1. System prompt
|
||||
2. Message history of the agent, truncated & prepended with running summary
|
||||
as needed
|
||||
3. `cycle_instruction`
|
||||
"""
|
||||
if not extra_messages:
|
||||
extra_messages = []
|
||||
|
||||
system_prompt = self.build_system_prompt(
|
||||
system_prompt, response_prefill = self.build_system_prompt(
|
||||
ai_profile=ai_profile,
|
||||
ai_directives=ai_directives,
|
||||
commands=commands,
|
||||
include_os_info=include_os_info,
|
||||
)
|
||||
system_prompt_tlength = count_message_tokens(ChatMessage.system(system_prompt))
|
||||
|
||||
user_task = f'"""{task}"""'
|
||||
user_task_tlength = count_message_tokens(ChatMessage.user(user_task))
|
||||
|
||||
response_format_instr = self.response_format_instruction(
|
||||
self.config.use_functions_api
|
||||
)
|
||||
extra_messages.append(ChatMessage.system(response_format_instr))
|
||||
|
||||
final_instruction_msg = ChatMessage.user(self.config.choose_action_instruction)
|
||||
final_instruction_tlength = count_message_tokens(final_instruction_msg)
|
||||
|
||||
if event_history:
|
||||
progress = self.compile_progress(
|
||||
event_history,
|
||||
count_tokens=count_tokens,
|
||||
max_tokens=(
|
||||
max_prompt_tokens
|
||||
- system_prompt_tlength
|
||||
- user_task_tlength
|
||||
- final_instruction_tlength
|
||||
- count_message_tokens(extra_messages)
|
||||
),
|
||||
)
|
||||
extra_messages.insert(
|
||||
0,
|
||||
ChatMessage.system(f"## Progress\n\n{progress}"),
|
||||
)
|
||||
|
||||
prompt = ChatPrompt(
|
||||
return ChatPrompt(
|
||||
messages=[
|
||||
ChatMessage.system(system_prompt),
|
||||
ChatMessage.user(user_task),
|
||||
*extra_messages,
|
||||
ChatMessage.user(f'"""{task}"""'),
|
||||
*messages,
|
||||
final_instruction_msg,
|
||||
],
|
||||
prefill_response=response_prefill,
|
||||
functions=commands if self.config.use_functions_api else [],
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
def build_system_prompt(
|
||||
self,
|
||||
ai_profile: AIProfile,
|
||||
ai_directives: AIDirectives,
|
||||
commands: list[CompletionModelFunction],
|
||||
include_os_info: bool,
|
||||
) -> str:
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Builds the system prompt.
|
||||
|
||||
Returns:
|
||||
str: The system prompt body
|
||||
str: The desired start for the LLM's response; used to steer the output
|
||||
"""
|
||||
response_fmt_instruction, response_prefill = self.response_format_instruction(
|
||||
self.config.use_functions_api
|
||||
)
|
||||
system_prompt_parts = (
|
||||
self._generate_intro_prompt(ai_profile)
|
||||
+ (self._generate_os_info() if include_os_info else [])
|
||||
@@ -248,69 +179,39 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
" in the next message. Your job is to complete the task while following"
|
||||
" your directives as given above, and terminate when your task is done."
|
||||
]
|
||||
+ ["## RESPONSE FORMAT\n" + response_fmt_instruction]
|
||||
)
|
||||
|
||||
# Join non-empty parts together into paragraph format
|
||||
return "\n\n".join(filter(None, system_prompt_parts)).strip("\n")
|
||||
return (
|
||||
"\n\n".join(filter(None, system_prompt_parts)).strip("\n"),
|
||||
response_prefill,
|
||||
)
|
||||
|
||||
def compile_progress(
|
||||
self,
|
||||
episode_history: list[Episode],
|
||||
max_tokens: Optional[int] = None,
|
||||
count_tokens: Optional[Callable[[str], int]] = None,
|
||||
) -> str:
|
||||
if max_tokens and not count_tokens:
|
||||
raise ValueError("count_tokens is required if max_tokens is set")
|
||||
|
||||
steps: list[str] = []
|
||||
tokens: int = 0
|
||||
n_episodes = len(episode_history)
|
||||
|
||||
for i, episode in enumerate(reversed(episode_history)):
|
||||
# Use full format for the latest 4 steps, summary or format for older steps
|
||||
if i < 4 or episode.summary is None:
|
||||
step_content = indent(episode.format(), 2).strip()
|
||||
else:
|
||||
step_content = episode.summary
|
||||
|
||||
step = f"* Step {n_episodes - i}: {step_content}"
|
||||
|
||||
if max_tokens and count_tokens:
|
||||
step_tokens = count_tokens(step)
|
||||
if tokens + step_tokens > max_tokens:
|
||||
break
|
||||
tokens += step_tokens
|
||||
|
||||
steps.insert(0, step)
|
||||
|
||||
return "\n\n".join(steps)
|
||||
|
||||
def response_format_instruction(self, use_functions_api: bool) -> str:
|
||||
def response_format_instruction(self, use_functions_api: bool) -> tuple[str, str]:
|
||||
response_schema = self.response_schema.copy(deep=True)
|
||||
if (
|
||||
use_functions_api
|
||||
and response_schema.properties
|
||||
and "command" in response_schema.properties
|
||||
and "use_tool" in response_schema.properties
|
||||
):
|
||||
del response_schema.properties["command"]
|
||||
del response_schema.properties["use_tool"]
|
||||
|
||||
# Unindent for performance
|
||||
response_format = re.sub(
|
||||
r"\n\s+",
|
||||
"\n",
|
||||
response_schema.to_typescript_object_interface("Response"),
|
||||
)
|
||||
|
||||
instruction = (
|
||||
"Respond with pure JSON containing your thoughts, " "and invoke a tool."
|
||||
if use_functions_api
|
||||
else "Respond with pure JSON."
|
||||
response_schema.to_typescript_object_interface(_RESPONSE_INTERFACE_NAME),
|
||||
)
|
||||
response_prefill = f'{{\n "{list(response_schema.properties.keys())[0]}":'
|
||||
|
||||
return (
|
||||
f"{instruction} "
|
||||
"The JSON object should be compatible with the TypeScript type `Response` "
|
||||
f"from the following:\n{response_format}"
|
||||
(
|
||||
f"YOU MUST ALWAYS RESPOND WITH A JSON OBJECT OF THE FOLLOWING TYPE:\n"
|
||||
f"{response_format}"
|
||||
+ ("\n\nYOU MUST ALSO INVOKE A TOOL!" if use_functions_api else "")
|
||||
),
|
||||
response_prefill,
|
||||
)
|
||||
|
||||
def _generate_intro_prompt(self, ai_profile: AIProfile) -> list[str]:
|
||||
@@ -374,7 +275,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
def parse_response_content(
|
||||
self,
|
||||
response: AssistantChatMessage,
|
||||
) -> Agent.ThoughtProcessOutput:
|
||||
) -> OneShotAgentActionProposal:
|
||||
if not response.content:
|
||||
raise InvalidAgentResponseError("Assistant response has no text content")
|
||||
|
||||
@@ -388,81 +289,13 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
)
|
||||
assistant_reply_dict = extract_dict_from_json(response.content)
|
||||
self.logger.debug(
|
||||
"Validating object extracted from LLM response:\n"
|
||||
"Parsing object extracted from LLM response:\n"
|
||||
f"{json.dumps(assistant_reply_dict, indent=4)}"
|
||||
)
|
||||
|
||||
_, errors = self.response_schema.validate_object(assistant_reply_dict)
|
||||
if errors:
|
||||
raise InvalidAgentResponseError(
|
||||
"Validation of response failed:\n "
|
||||
+ ";\n ".join([str(e) for e in errors])
|
||||
)
|
||||
|
||||
# Get command name and arguments
|
||||
command_name, arguments = extract_command(
|
||||
assistant_reply_dict, response, self.config.use_functions_api
|
||||
)
|
||||
return command_name, arguments, assistant_reply_dict
|
||||
|
||||
|
||||
#############
|
||||
# Utilities #
|
||||
#############
|
||||
|
||||
|
||||
def extract_command(
|
||||
assistant_reply_json: dict,
|
||||
assistant_reply: AssistantChatMessage,
|
||||
use_openai_functions_api: bool,
|
||||
) -> tuple[str, dict[str, str]]:
|
||||
"""Parse the response and return the command name and arguments
|
||||
|
||||
Args:
|
||||
assistant_reply_json (dict): The response object from the AI
|
||||
assistant_reply (AssistantChatMessage): The model response from the AI
|
||||
config (Config): The config object
|
||||
|
||||
Returns:
|
||||
tuple: The command name and arguments
|
||||
|
||||
Raises:
|
||||
json.decoder.JSONDecodeError: If the response is not valid JSON
|
||||
|
||||
Exception: If any other error occurs
|
||||
"""
|
||||
if use_openai_functions_api:
|
||||
if not assistant_reply.tool_calls:
|
||||
raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
|
||||
assistant_reply_json["command"] = {
|
||||
"name": assistant_reply.tool_calls[0].function.name,
|
||||
"args": assistant_reply.tool_calls[0].function.arguments,
|
||||
}
|
||||
try:
|
||||
if not isinstance(assistant_reply_json, dict):
|
||||
raise InvalidAgentResponseError(
|
||||
f"The previous message sent was not a dictionary {assistant_reply_json}"
|
||||
)
|
||||
|
||||
if "command" not in assistant_reply_json:
|
||||
raise InvalidAgentResponseError("Missing 'command' object in JSON")
|
||||
|
||||
command = assistant_reply_json["command"]
|
||||
if not isinstance(command, dict):
|
||||
raise InvalidAgentResponseError("'command' object is not a dictionary")
|
||||
|
||||
if "name" not in command:
|
||||
raise InvalidAgentResponseError("Missing 'name' field in 'command' object")
|
||||
|
||||
command_name = command["name"]
|
||||
|
||||
# Use an empty dictionary if 'args' field is not present in 'command' object
|
||||
arguments = command.get("args", {})
|
||||
|
||||
return command_name, arguments
|
||||
|
||||
except json.decoder.JSONDecodeError:
|
||||
raise InvalidAgentResponseError("Invalid JSON")
|
||||
|
||||
except Exception as e:
|
||||
raise InvalidAgentResponseError(str(e))
|
||||
parsed_response = OneShotAgentActionProposal.parse_obj(assistant_reply_dict)
|
||||
if self.config.use_functions_api:
|
||||
if not response.tool_calls:
|
||||
raise InvalidAgentResponseError("Assistant did not use a tool")
|
||||
parsed_response.use_tool = response.tool_calls[0].function
|
||||
return parsed_response
|
||||
|
||||
51
autogpts/autogpt/autogpt/agents/protocols.py
Normal file
51
autogpts/autogpt/autogpt/agents/protocols.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Iterator
|
||||
|
||||
from autogpt.agents.components import AgentComponent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.base import BaseAgentActionProposal
|
||||
from autogpt.core.resource.model_providers.schema import ChatMessage
|
||||
from autogpt.models.action_history import ActionResult
|
||||
from autogpt.models.command import Command
|
||||
|
||||
|
||||
class DirectiveProvider(AgentComponent):
|
||||
def get_constraints(self) -> Iterator[str]:
|
||||
return iter([])
|
||||
|
||||
def get_resources(self) -> Iterator[str]:
|
||||
return iter([])
|
||||
|
||||
def get_best_practices(self) -> Iterator[str]:
|
||||
return iter([])
|
||||
|
||||
|
||||
class CommandProvider(AgentComponent):
|
||||
@abstractmethod
|
||||
def get_commands(self) -> Iterator["Command"]:
|
||||
...
|
||||
|
||||
|
||||
class MessageProvider(AgentComponent):
|
||||
@abstractmethod
|
||||
def get_messages(self) -> Iterator["ChatMessage"]:
|
||||
...
|
||||
|
||||
|
||||
class AfterParse(AgentComponent):
|
||||
@abstractmethod
|
||||
def after_parse(self, result: "BaseAgentActionProposal") -> None:
|
||||
...
|
||||
|
||||
|
||||
class ExecutionFailure(AgentComponent):
|
||||
@abstractmethod
|
||||
def execution_failure(self, error: Exception) -> None:
|
||||
...
|
||||
|
||||
|
||||
class AfterExecute(AgentComponent):
|
||||
@abstractmethod
|
||||
def after_execute(self, result: "ActionResult") -> None:
|
||||
...
|
||||
@@ -1,108 +0,0 @@
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from autogpt.core.resource.model_providers.schema import CompletionModelFunction
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
logger = logging.getLogger("PromptScratchpad")
|
||||
|
||||
|
||||
class CallableCompletionModelFunction(CompletionModelFunction):
|
||||
method: Callable
|
||||
|
||||
|
||||
class PromptScratchpad(BaseModel):
|
||||
commands: dict[str, CallableCompletionModelFunction] = Field(default_factory=dict)
|
||||
resources: list[str] = Field(default_factory=list)
|
||||
constraints: list[str] = Field(default_factory=list)
|
||||
best_practices: list[str] = Field(default_factory=list)
|
||||
|
||||
def add_constraint(self, constraint: str) -> None:
|
||||
"""
|
||||
Add a constraint to the constraints list.
|
||||
|
||||
Params:
|
||||
constraint (str): The constraint to be added.
|
||||
"""
|
||||
if constraint not in self.constraints:
|
||||
self.constraints.append(constraint)
|
||||
|
||||
def add_command(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
params: dict[str, str | dict],
|
||||
function: Callable,
|
||||
) -> None:
|
||||
"""
|
||||
Registers a command.
|
||||
|
||||
*Should only be used by plugins.* Native commands should be added
|
||||
directly to the CommandRegistry.
|
||||
|
||||
Params:
|
||||
name (str): The name of the command (e.g. `command_name`).
|
||||
description (str): The description of the command.
|
||||
params (dict, optional): A dictionary containing argument names and their
|
||||
types. Defaults to an empty dictionary.
|
||||
function (callable, optional): A callable function to be called when
|
||||
the command is executed. Defaults to None.
|
||||
"""
|
||||
for p, s in params.items():
|
||||
invalid = False
|
||||
if type(s) is str and s not in JSONSchema.Type._value2member_map_:
|
||||
invalid = True
|
||||
logger.warning(
|
||||
f"Cannot add command '{name}':"
|
||||
f" parameter '{p}' has invalid type '{s}'."
|
||||
f" Valid types are: {JSONSchema.Type._value2member_map_.keys()}"
|
||||
)
|
||||
elif isinstance(s, dict):
|
||||
try:
|
||||
JSONSchema.from_dict(s)
|
||||
except KeyError:
|
||||
invalid = True
|
||||
if invalid:
|
||||
return
|
||||
|
||||
command = CallableCompletionModelFunction(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters={
|
||||
name: JSONSchema(type=JSONSchema.Type._value2member_map_[spec])
|
||||
if type(spec) is str
|
||||
else JSONSchema.from_dict(spec)
|
||||
for name, spec in params.items()
|
||||
},
|
||||
method=function,
|
||||
)
|
||||
|
||||
if name in self.commands:
|
||||
if description == self.commands[name].description:
|
||||
return
|
||||
logger.warning(
|
||||
f"Replacing command {self.commands[name]} with conflicting {command}"
|
||||
)
|
||||
self.commands[name] = command
|
||||
|
||||
def add_resource(self, resource: str) -> None:
|
||||
"""
|
||||
Add a resource to the resources list.
|
||||
|
||||
Params:
|
||||
resource (str): The resource to be added.
|
||||
"""
|
||||
if resource not in self.resources:
|
||||
self.resources.append(resource)
|
||||
|
||||
def add_best_practice(self, best_practice: str) -> None:
|
||||
"""
|
||||
Add an item to the list of best practices.
|
||||
|
||||
Params:
|
||||
best_practice (str): The best practice item to be added.
|
||||
"""
|
||||
if best_practice not in self.best_practices:
|
||||
self.best_practices.append(best_practice)
|
||||
@@ -31,17 +31,13 @@ from sentry_sdk import set_user
|
||||
from autogpt.agent_factory.configurators import configure_agent_with_state
|
||||
from autogpt.agent_factory.generators import generate_agent_for_task
|
||||
from autogpt.agent_manager import AgentManager
|
||||
from autogpt.agents.utils.exceptions import AgentFinished
|
||||
from autogpt.app.utils import is_port_free
|
||||
from autogpt.commands.system import finish
|
||||
from autogpt.commands.user_interaction import ask_user
|
||||
from autogpt.config import Config
|
||||
from autogpt.core.resource.model_providers import ChatModelProvider
|
||||
from autogpt.core.resource.model_providers.openai import OpenAIProvider
|
||||
from autogpt.core.resource.model_providers.schema import ModelProviderBudget
|
||||
from autogpt.core.resource.model_providers import ChatModelProvider, ModelProviderBudget
|
||||
from autogpt.file_storage import FileStorage
|
||||
from autogpt.logs.utils import fmt_kwargs
|
||||
from autogpt.models.action_history import ActionErrorResult, ActionSuccessResult
|
||||
from autogpt.utils.exceptions import AgentFinished
|
||||
from autogpt.utils.utils import DEFAULT_ASK_COMMAND, DEFAULT_FINISH_COMMAND
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -149,7 +145,7 @@ class AgentProtocolServer:
|
||||
file_storage=self.file_storage,
|
||||
llm_provider=self._get_task_llm_provider(task),
|
||||
)
|
||||
await task_agent.save_state()
|
||||
await task_agent.file_manager.save_state()
|
||||
|
||||
return task
|
||||
|
||||
@@ -202,7 +198,7 @@ class AgentProtocolServer:
|
||||
# To prevent this from interfering with the agent's process, we ignore the input
|
||||
# of this first step request, and just generate the first step proposal.
|
||||
is_init_step = not bool(agent.event_history)
|
||||
execute_command, execute_command_args, execute_result = None, None, None
|
||||
last_proposal, tool_result = None, None
|
||||
execute_approved = False
|
||||
|
||||
# HACK: only for compatibility with AGBenchmark
|
||||
@@ -216,13 +212,11 @@ class AgentProtocolServer:
|
||||
and agent.event_history.current_episode
|
||||
and not agent.event_history.current_episode.result
|
||||
):
|
||||
execute_command = agent.event_history.current_episode.action.name
|
||||
execute_command_args = agent.event_history.current_episode.action.args
|
||||
last_proposal = agent.event_history.current_episode.action
|
||||
execute_approved = not user_input
|
||||
|
||||
logger.debug(
|
||||
f"Agent proposed command"
|
||||
f" {execute_command}({fmt_kwargs(execute_command_args)})."
|
||||
f"Agent proposed command {last_proposal.use_tool}."
|
||||
f" User input/feedback: {repr(user_input)}"
|
||||
)
|
||||
|
||||
@@ -230,22 +224,25 @@ class AgentProtocolServer:
|
||||
step = await self.db.create_step(
|
||||
task_id=task_id,
|
||||
input=step_request,
|
||||
is_last=execute_command == finish.__name__ and execute_approved,
|
||||
is_last=(
|
||||
last_proposal is not None
|
||||
and last_proposal.use_tool.name == DEFAULT_FINISH_COMMAND
|
||||
and execute_approved
|
||||
),
|
||||
)
|
||||
agent.llm_provider = self._get_task_llm_provider(task, step.step_id)
|
||||
|
||||
# Execute previously proposed action
|
||||
if execute_command:
|
||||
assert execute_command_args is not None
|
||||
agent.workspace.on_write_file = lambda path: self._on_agent_write_file(
|
||||
task=task, step=step, relative_path=path
|
||||
if last_proposal:
|
||||
agent.file_manager.workspace.on_write_file = (
|
||||
lambda path: self._on_agent_write_file(
|
||||
task=task, step=step, relative_path=path
|
||||
)
|
||||
)
|
||||
|
||||
if execute_command == ask_user.__name__: # HACK
|
||||
execute_result = ActionSuccessResult(outputs=user_input)
|
||||
agent.event_history.register_result(execute_result)
|
||||
elif not execute_command:
|
||||
execute_result = None
|
||||
if last_proposal.use_tool.name == DEFAULT_ASK_COMMAND:
|
||||
tool_result = ActionSuccessResult(outputs=user_input)
|
||||
agent.event_history.register_result(tool_result)
|
||||
elif execute_approved:
|
||||
step = await self.db.update_step(
|
||||
task_id=task_id,
|
||||
@@ -255,10 +252,7 @@ class AgentProtocolServer:
|
||||
|
||||
try:
|
||||
# Execute previously proposed action
|
||||
execute_result = await agent.execute(
|
||||
command_name=execute_command,
|
||||
command_args=execute_command_args,
|
||||
)
|
||||
tool_result = await agent.execute(last_proposal)
|
||||
except AgentFinished:
|
||||
additional_output = {}
|
||||
task_total_cost = agent.llm_provider.get_incurred_cost()
|
||||
@@ -272,23 +266,20 @@ class AgentProtocolServer:
|
||||
step = await self.db.update_step(
|
||||
task_id=task_id,
|
||||
step_id=step.step_id,
|
||||
output=execute_command_args["reason"],
|
||||
output=last_proposal.use_tool.arguments["reason"],
|
||||
additional_output=additional_output,
|
||||
)
|
||||
await agent.save_state()
|
||||
await agent.file_manager.save_state()
|
||||
return step
|
||||
else:
|
||||
assert user_input
|
||||
execute_result = await agent.execute(
|
||||
command_name="human_feedback", # HACK
|
||||
command_args={},
|
||||
user_input=user_input,
|
||||
)
|
||||
tool_result = await agent.do_not_execute(last_proposal, user_input)
|
||||
|
||||
# Propose next action
|
||||
try:
|
||||
next_command, next_command_args, raw_output = await agent.propose_action()
|
||||
logger.debug(f"AI output: {raw_output}")
|
||||
assistant_response = await agent.propose_action()
|
||||
next_tool_to_use = assistant_response.use_tool
|
||||
logger.debug(f"AI output: {assistant_response.thoughts}")
|
||||
except Exception as e:
|
||||
step = await self.db.update_step(
|
||||
task_id=task_id,
|
||||
@@ -301,44 +292,44 @@ class AgentProtocolServer:
|
||||
# Format step output
|
||||
output = (
|
||||
(
|
||||
f"`{execute_command}({fmt_kwargs(execute_command_args)})` returned:"
|
||||
+ ("\n\n" if "\n" in str(execute_result) else " ")
|
||||
+ f"{execute_result}\n\n"
|
||||
f"`{last_proposal.use_tool}` returned:"
|
||||
+ ("\n\n" if "\n" in str(tool_result) else " ")
|
||||
+ f"{tool_result}\n\n"
|
||||
)
|
||||
if execute_command_args and execute_command != ask_user.__name__
|
||||
if last_proposal and last_proposal.use_tool.name != DEFAULT_ASK_COMMAND
|
||||
else ""
|
||||
)
|
||||
output += f"{raw_output['thoughts']['speak']}\n\n"
|
||||
output += f"{assistant_response.thoughts.speak}\n\n"
|
||||
output += (
|
||||
f"Next Command: {next_command}({fmt_kwargs(next_command_args)})"
|
||||
if next_command != ask_user.__name__
|
||||
else next_command_args["question"]
|
||||
f"Next Command: {next_tool_to_use}"
|
||||
if next_tool_to_use.name != DEFAULT_ASK_COMMAND
|
||||
else next_tool_to_use.arguments["question"]
|
||||
)
|
||||
|
||||
additional_output = {
|
||||
**(
|
||||
{
|
||||
"last_action": {
|
||||
"name": execute_command,
|
||||
"args": execute_command_args,
|
||||
"name": last_proposal.use_tool.name,
|
||||
"args": last_proposal.use_tool.arguments,
|
||||
"result": (
|
||||
""
|
||||
if execute_result is None
|
||||
if tool_result is None
|
||||
else (
|
||||
orjson.loads(execute_result.json())
|
||||
if not isinstance(execute_result, ActionErrorResult)
|
||||
orjson.loads(tool_result.json())
|
||||
if not isinstance(tool_result, ActionErrorResult)
|
||||
else {
|
||||
"error": str(execute_result.error),
|
||||
"reason": execute_result.reason,
|
||||
"error": str(tool_result.error),
|
||||
"reason": tool_result.reason,
|
||||
}
|
||||
)
|
||||
),
|
||||
},
|
||||
}
|
||||
if not is_init_step
|
||||
if last_proposal and tool_result
|
||||
else {}
|
||||
),
|
||||
**raw_output,
|
||||
**assistant_response.dict(),
|
||||
}
|
||||
|
||||
task_cumulative_cost = agent.llm_provider.get_incurred_cost()
|
||||
@@ -357,7 +348,7 @@ class AgentProtocolServer:
|
||||
additional_output=additional_output,
|
||||
)
|
||||
|
||||
await agent.save_state()
|
||||
await agent.file_manager.save_state()
|
||||
return step
|
||||
|
||||
async def _on_agent_write_file(
|
||||
@@ -472,20 +463,18 @@ class AgentProtocolServer:
|
||||
if task.additional_input and (user_id := task.additional_input.get("user_id")):
|
||||
_extra_request_headers["AutoGPT-UserID"] = user_id
|
||||
|
||||
task_llm_provider = None
|
||||
if isinstance(self.llm_provider, OpenAIProvider):
|
||||
settings = self.llm_provider._settings.copy()
|
||||
settings.budget = task_llm_budget
|
||||
settings.configuration = task_llm_provider_config # type: ignore
|
||||
task_llm_provider = OpenAIProvider(
|
||||
settings=settings,
|
||||
logger=logger.getChild(f"Task-{task.task_id}_OpenAIProvider"),
|
||||
)
|
||||
settings = self.llm_provider._settings.copy()
|
||||
settings.budget = task_llm_budget
|
||||
settings.configuration = task_llm_provider_config
|
||||
task_llm_provider = self.llm_provider.__class__(
|
||||
settings=settings,
|
||||
logger=logger.getChild(
|
||||
f"Task-{task.task_id}_{self.llm_provider.__class__.__name__}"
|
||||
),
|
||||
)
|
||||
self._task_budgets[task.task_id] = task_llm_provider._budget # type: ignore
|
||||
|
||||
if task_llm_provider and task_llm_provider._budget:
|
||||
self._task_budgets[task.task_id] = task_llm_provider._budget
|
||||
|
||||
return task_llm_provider or self.llm_provider
|
||||
return task_llm_provider
|
||||
|
||||
|
||||
def task_agent_id(task_id: str | int) -> str:
|
||||
|
||||
@@ -8,12 +8,12 @@ from typing import Literal, Optional
|
||||
import click
|
||||
from colorama import Back, Fore, Style
|
||||
|
||||
from autogpt import utils
|
||||
from autogpt.config import Config
|
||||
from autogpt.config.config import GPT_3_MODEL, GPT_4_MODEL
|
||||
from autogpt.core.resource.model_providers.openai import OpenAIModelName, OpenAIProvider
|
||||
from autogpt.core.resource.model_providers import ModelName, MultiProvider
|
||||
from autogpt.logs.helpers import request_user_double_check
|
||||
from autogpt.memory.vector import get_supported_memory_backends
|
||||
from autogpt.utils import utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -150,11 +150,11 @@ async def apply_overrides_to_config(
|
||||
|
||||
|
||||
async def check_model(
|
||||
model_name: OpenAIModelName, model_type: Literal["smart_llm", "fast_llm"]
|
||||
) -> OpenAIModelName:
|
||||
model_name: ModelName, model_type: Literal["smart_llm", "fast_llm"]
|
||||
) -> ModelName:
|
||||
"""Check if model is available for use. If not, return gpt-3.5-turbo."""
|
||||
openai = OpenAIProvider()
|
||||
models = await openai.get_available_models()
|
||||
multi_provider = MultiProvider()
|
||||
models = await multi_provider.get_available_models()
|
||||
|
||||
if any(model_name == m.name for m in models):
|
||||
return model_name
|
||||
|
||||
@@ -18,17 +18,16 @@ from forge.sdk.db import AgentDB
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.base import BaseAgentActionProposal
|
||||
|
||||
from autogpt.agent_factory.configurators import configure_agent_with_state, create_agent
|
||||
from autogpt.agent_factory.profile_generator import generate_agent_profile_for_task
|
||||
from autogpt.agent_manager import AgentManager
|
||||
from autogpt.agents import AgentThoughts, CommandArgs, CommandName
|
||||
from autogpt.agents.utils.exceptions import AgentTerminated, InvalidAgentResponseError
|
||||
from autogpt.agents.prompt_strategies.one_shot import AssistantThoughts
|
||||
from autogpt.commands.execute_code import (
|
||||
is_docker_available,
|
||||
we_are_running_in_a_docker_container,
|
||||
)
|
||||
from autogpt.commands.system import finish
|
||||
from autogpt.config import (
|
||||
AIDirectives,
|
||||
AIProfile,
|
||||
@@ -36,14 +35,15 @@ from autogpt.config import (
|
||||
ConfigBuilder,
|
||||
assert_config_has_openai_api_key,
|
||||
)
|
||||
from autogpt.core.resource.model_providers.openai import OpenAIProvider
|
||||
from autogpt.core.resource.model_providers import MultiProvider
|
||||
from autogpt.core.runner.client_lib.utils import coroutine
|
||||
from autogpt.file_storage import FileStorageBackendName, get_storage
|
||||
from autogpt.logs.config import configure_chat_plugins, configure_logging
|
||||
from autogpt.logs.config import configure_logging
|
||||
from autogpt.logs.helpers import print_attribute, speak
|
||||
from autogpt.models.action_history import ActionInterruptedByHuman
|
||||
from autogpt.plugins import scan_plugins
|
||||
from scripts.install_plugin_deps import install_plugin_dependencies
|
||||
from autogpt.models.utils import ModelWithSummary
|
||||
from autogpt.utils.exceptions import AgentTerminated, InvalidAgentResponseError
|
||||
from autogpt.utils.utils import DEFAULT_FINISH_COMMAND
|
||||
|
||||
from .configurator import apply_overrides_to_config
|
||||
from .setup import apply_overrides_to_ai_settings, interactively_revise_ai_settings
|
||||
@@ -102,6 +102,7 @@ async def run_auto_gpt(
|
||||
level=log_level,
|
||||
log_format=log_format,
|
||||
log_file_format=log_file_format,
|
||||
config=config.logging,
|
||||
tts_config=config.tts_config,
|
||||
)
|
||||
|
||||
@@ -122,7 +123,7 @@ async def run_auto_gpt(
|
||||
skip_news=skip_news,
|
||||
)
|
||||
|
||||
llm_provider = _configure_openai_provider(config)
|
||||
llm_provider = _configure_llm_provider(config)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -165,12 +166,6 @@ async def run_auto_gpt(
|
||||
title_color=Fore.YELLOW,
|
||||
)
|
||||
|
||||
if install_plugin_deps:
|
||||
install_plugin_dependencies()
|
||||
|
||||
config.plugins = scan_plugins(config)
|
||||
configure_chat_plugins(config)
|
||||
|
||||
# Let user choose an existing agent to run
|
||||
agent_manager = AgentManager(file_storage)
|
||||
existing_agents = agent_manager.list_agents()
|
||||
@@ -190,7 +185,7 @@ async def run_auto_gpt(
|
||||
) <= len(existing_agents):
|
||||
load_existing_agent = existing_agents[int(load_existing_agent) - 1]
|
||||
|
||||
if load_existing_agent not in existing_agents:
|
||||
if load_existing_agent != "" and load_existing_agent not in existing_agents:
|
||||
logger.info(
|
||||
f"Unknown agent '{load_existing_agent}', "
|
||||
f"creating a new one instead.",
|
||||
@@ -234,12 +229,12 @@ async def run_auto_gpt(
|
||||
)
|
||||
|
||||
if (
|
||||
agent.event_history.current_episode
|
||||
and agent.event_history.current_episode.action.name == finish.__name__
|
||||
and not agent.event_history.current_episode.result
|
||||
(current_episode := agent.event_history.current_episode)
|
||||
and current_episode.action.use_tool.name == DEFAULT_FINISH_COMMAND
|
||||
and not current_episode.result
|
||||
):
|
||||
# Agent was resumed after `finish` -> rewrite result of `finish` action
|
||||
finish_reason = agent.event_history.current_episode.action.args["reason"]
|
||||
finish_reason = current_episode.action.use_tool.arguments["reason"]
|
||||
print(f"Agent previously self-terminated; reason: '{finish_reason}'")
|
||||
new_assignment = clean_input(
|
||||
config, "Please give a follow-up question or assignment:"
|
||||
@@ -327,11 +322,13 @@ async def run_auto_gpt(
|
||||
llm_provider=llm_provider,
|
||||
)
|
||||
|
||||
if not agent.config.allow_fs_access:
|
||||
file_manager = agent.file_manager
|
||||
|
||||
if file_manager and not agent.config.allow_fs_access:
|
||||
logger.info(
|
||||
f"{Fore.YELLOW}"
|
||||
"NOTE: All files/directories created by this agent can be found "
|
||||
f"inside its workspace at:{Fore.RESET} {agent.workspace.root}",
|
||||
f"inside its workspace at:{Fore.RESET} {file_manager.workspace.root}",
|
||||
extra={"preserve_color": True},
|
||||
)
|
||||
|
||||
@@ -351,7 +348,9 @@ async def run_auto_gpt(
|
||||
" or enter a different ID to save to:",
|
||||
)
|
||||
# TODO: allow many-to-one relations of agents and workspaces
|
||||
await agent.save_state(save_as_id if not save_as_id.isspace() else None)
|
||||
await agent.file_manager.save_state(
|
||||
save_as_id.strip() if not save_as_id.isspace() else None
|
||||
)
|
||||
|
||||
|
||||
@coroutine
|
||||
@@ -384,6 +383,7 @@ async def run_auto_gpt_server(
|
||||
level=log_level,
|
||||
log_format=log_format,
|
||||
log_file_format=log_file_format,
|
||||
config=config.logging,
|
||||
tts_config=config.tts_config,
|
||||
)
|
||||
|
||||
@@ -399,12 +399,7 @@ async def run_auto_gpt_server(
|
||||
allow_downloads=allow_downloads,
|
||||
)
|
||||
|
||||
llm_provider = _configure_openai_provider(config)
|
||||
|
||||
if install_plugin_deps:
|
||||
install_plugin_dependencies()
|
||||
|
||||
config.plugins = scan_plugins(config)
|
||||
llm_provider = _configure_llm_provider(config)
|
||||
|
||||
# Set up & start server
|
||||
database = AgentDB(
|
||||
@@ -426,24 +421,12 @@ async def run_auto_gpt_server(
|
||||
)
|
||||
|
||||
|
||||
def _configure_openai_provider(config: Config) -> OpenAIProvider:
|
||||
"""Create a configured OpenAIProvider object.
|
||||
|
||||
Args:
|
||||
config: The program's configuration.
|
||||
|
||||
Returns:
|
||||
A configured OpenAIProvider object.
|
||||
"""
|
||||
if config.openai_credentials is None:
|
||||
raise RuntimeError("OpenAI key is not configured")
|
||||
|
||||
openai_settings = OpenAIProvider.default_settings.copy(deep=True)
|
||||
openai_settings.credentials = config.openai_credentials
|
||||
return OpenAIProvider(
|
||||
settings=openai_settings,
|
||||
logger=logging.getLogger("OpenAIProvider"),
|
||||
)
|
||||
def _configure_llm_provider(config: Config) -> MultiProvider:
|
||||
multi_provider = MultiProvider()
|
||||
for model in [config.smart_llm, config.fast_llm]:
|
||||
# Ensure model providers for configured LLMs are available
|
||||
multi_provider.get_model_provider(model)
|
||||
return multi_provider
|
||||
|
||||
|
||||
def _get_cycle_budget(continuous_mode: bool, continuous_limit: int) -> int | float:
|
||||
@@ -537,11 +520,7 @@ async def run_interaction_loop(
|
||||
# Have the agent determine the next action to take.
|
||||
with spinner:
|
||||
try:
|
||||
(
|
||||
command_name,
|
||||
command_args,
|
||||
assistant_reply_dict,
|
||||
) = await agent.propose_action()
|
||||
action_proposal = await agent.propose_action()
|
||||
except InvalidAgentResponseError as e:
|
||||
logger.warning(f"The agent's thoughts could not be parsed: {e}")
|
||||
consecutive_failures += 1
|
||||
@@ -564,9 +543,7 @@ async def run_interaction_loop(
|
||||
# Print the assistant's thoughts and the next command to the user.
|
||||
update_user(
|
||||
ai_profile,
|
||||
command_name,
|
||||
command_args,
|
||||
assistant_reply_dict,
|
||||
action_proposal,
|
||||
speak_mode=legacy_config.tts_config.speak_mode,
|
||||
)
|
||||
|
||||
@@ -575,12 +552,12 @@ async def run_interaction_loop(
|
||||
##################
|
||||
handle_stop_signal()
|
||||
if cycles_remaining == 1: # Last cycle
|
||||
user_feedback, user_input, new_cycles_remaining = await get_user_feedback(
|
||||
feedback_type, feedback, new_cycles_remaining = await get_user_feedback(
|
||||
legacy_config,
|
||||
ai_profile,
|
||||
)
|
||||
|
||||
if user_feedback == UserFeedback.AUTHORIZE:
|
||||
if feedback_type == UserFeedback.AUTHORIZE:
|
||||
if new_cycles_remaining is not None:
|
||||
# Case 1: User is altering the cycle budget.
|
||||
if cycle_budget > 1:
|
||||
@@ -604,13 +581,13 @@ async def run_interaction_loop(
|
||||
"-=-=-=-=-=-=-= COMMAND AUTHORISED BY USER -=-=-=-=-=-=-=",
|
||||
extra={"color": Fore.MAGENTA},
|
||||
)
|
||||
elif user_feedback == UserFeedback.EXIT:
|
||||
elif feedback_type == UserFeedback.EXIT:
|
||||
logger.warning("Exiting...")
|
||||
exit()
|
||||
else: # user_feedback == UserFeedback.TEXT
|
||||
command_name = "human_feedback"
|
||||
pass
|
||||
else:
|
||||
user_input = ""
|
||||
feedback = ""
|
||||
# First log new-line so user can differentiate sections better in console
|
||||
print()
|
||||
if cycles_remaining != math.inf:
|
||||
@@ -625,33 +602,31 @@ async def run_interaction_loop(
|
||||
# Decrement the cycle counter first to reduce the likelihood of a SIGINT
|
||||
# happening during command execution, setting the cycles remaining to 1,
|
||||
# and then having the decrement set it to 0, exiting the application.
|
||||
if command_name != "human_feedback":
|
||||
if not feedback:
|
||||
cycles_remaining -= 1
|
||||
|
||||
if not command_name:
|
||||
if not action_proposal.use_tool:
|
||||
continue
|
||||
|
||||
handle_stop_signal()
|
||||
|
||||
if command_name:
|
||||
result = await agent.execute(command_name, command_args, user_input)
|
||||
if not feedback:
|
||||
result = await agent.execute(action_proposal)
|
||||
else:
|
||||
result = await agent.do_not_execute(action_proposal, feedback)
|
||||
|
||||
if result.status == "success":
|
||||
logger.info(
|
||||
result, extra={"title": "SYSTEM:", "title_color": Fore.YELLOW}
|
||||
)
|
||||
elif result.status == "error":
|
||||
logger.warning(
|
||||
f"Command {command_name} returned an error: "
|
||||
f"{result.error or result.reason}"
|
||||
)
|
||||
if result.status == "success":
|
||||
logger.info(result, extra={"title": "SYSTEM:", "title_color": Fore.YELLOW})
|
||||
elif result.status == "error":
|
||||
logger.warning(
|
||||
f"Command {action_proposal.use_tool.name} returned an error: "
|
||||
f"{result.error or result.reason}"
|
||||
)
|
||||
|
||||
|
||||
def update_user(
|
||||
ai_profile: AIProfile,
|
||||
command_name: CommandName,
|
||||
command_args: CommandArgs,
|
||||
assistant_reply_dict: AgentThoughts,
|
||||
action_proposal: "BaseAgentActionProposal",
|
||||
speak_mode: bool = False,
|
||||
) -> None:
|
||||
"""Prints the assistant's thoughts and the next command to the user.
|
||||
@@ -667,18 +642,19 @@ def update_user(
|
||||
|
||||
print_assistant_thoughts(
|
||||
ai_name=ai_profile.ai_name,
|
||||
assistant_reply_json_valid=assistant_reply_dict,
|
||||
thoughts=action_proposal.thoughts,
|
||||
speak_mode=speak_mode,
|
||||
)
|
||||
|
||||
if speak_mode:
|
||||
speak(f"I want to execute {command_name}")
|
||||
speak(f"I want to execute {action_proposal.use_tool.name}")
|
||||
|
||||
# First log new-line so user can differentiate sections better in console
|
||||
print()
|
||||
safe_tool_name = remove_ansi_escape(action_proposal.use_tool.name)
|
||||
logger.info(
|
||||
f"COMMAND = {Fore.CYAN}{remove_ansi_escape(command_name)}{Style.RESET_ALL} "
|
||||
f"ARGUMENTS = {Fore.CYAN}{command_args}{Style.RESET_ALL}",
|
||||
f"COMMAND = {Fore.CYAN}{safe_tool_name}{Style.RESET_ALL} "
|
||||
f"ARGUMENTS = {Fore.CYAN}{action_proposal.use_tool.arguments}{Style.RESET_ALL}",
|
||||
extra={
|
||||
"title": "NEXT ACTION:",
|
||||
"title_color": Fore.CYAN,
|
||||
@@ -719,12 +695,7 @@ async def get_user_feedback(
|
||||
|
||||
while user_feedback is None:
|
||||
# Get input from user
|
||||
if config.chat_messages_enabled:
|
||||
console_input = clean_input(config, "Waiting for your response...")
|
||||
else:
|
||||
console_input = clean_input(
|
||||
config, Fore.MAGENTA + "Input:" + Style.RESET_ALL
|
||||
)
|
||||
console_input = clean_input(config, Fore.MAGENTA + "Input:" + Style.RESET_ALL)
|
||||
|
||||
# Parse user input
|
||||
if console_input.lower().strip() == config.authorise_key:
|
||||
@@ -752,56 +723,59 @@ async def get_user_feedback(
|
||||
|
||||
def print_assistant_thoughts(
|
||||
ai_name: str,
|
||||
assistant_reply_json_valid: dict,
|
||||
thoughts: str | ModelWithSummary | AssistantThoughts,
|
||||
speak_mode: bool = False,
|
||||
) -> None:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
assistant_thoughts_reasoning = None
|
||||
assistant_thoughts_plan = None
|
||||
assistant_thoughts_speak = None
|
||||
assistant_thoughts_criticism = None
|
||||
|
||||
assistant_thoughts = assistant_reply_json_valid.get("thoughts", {})
|
||||
assistant_thoughts_text = remove_ansi_escape(assistant_thoughts.get("text", ""))
|
||||
if assistant_thoughts:
|
||||
assistant_thoughts_reasoning = remove_ansi_escape(
|
||||
assistant_thoughts.get("reasoning", "")
|
||||
)
|
||||
assistant_thoughts_plan = remove_ansi_escape(assistant_thoughts.get("plan", ""))
|
||||
assistant_thoughts_criticism = remove_ansi_escape(
|
||||
assistant_thoughts.get("self_criticism", "")
|
||||
)
|
||||
assistant_thoughts_speak = remove_ansi_escape(
|
||||
assistant_thoughts.get("speak", "")
|
||||
)
|
||||
print_attribute(
|
||||
f"{ai_name.upper()} THOUGHTS", assistant_thoughts_text, title_color=Fore.YELLOW
|
||||
thoughts_text = remove_ansi_escape(
|
||||
thoughts.text
|
||||
if isinstance(thoughts, AssistantThoughts)
|
||||
else thoughts.summary()
|
||||
if isinstance(thoughts, ModelWithSummary)
|
||||
else thoughts
|
||||
)
|
||||
print_attribute("REASONING", assistant_thoughts_reasoning, title_color=Fore.YELLOW)
|
||||
if assistant_thoughts_plan:
|
||||
print_attribute("PLAN", "", title_color=Fore.YELLOW)
|
||||
# If it's a list, join it into a string
|
||||
if isinstance(assistant_thoughts_plan, list):
|
||||
assistant_thoughts_plan = "\n".join(assistant_thoughts_plan)
|
||||
elif isinstance(assistant_thoughts_plan, dict):
|
||||
assistant_thoughts_plan = str(assistant_thoughts_plan)
|
||||
|
||||
# Split the input_string using the newline character and dashes
|
||||
lines = assistant_thoughts_plan.split("\n")
|
||||
for line in lines:
|
||||
line = line.lstrip("- ")
|
||||
logger.info(line.strip(), extra={"title": "- ", "title_color": Fore.GREEN})
|
||||
print_attribute(
|
||||
"CRITICISM", f"{assistant_thoughts_criticism}", title_color=Fore.YELLOW
|
||||
f"{ai_name.upper()} THOUGHTS", thoughts_text, title_color=Fore.YELLOW
|
||||
)
|
||||
|
||||
# Speak the assistant's thoughts
|
||||
if assistant_thoughts_speak:
|
||||
if speak_mode:
|
||||
speak(assistant_thoughts_speak)
|
||||
else:
|
||||
print_attribute("SPEAK", assistant_thoughts_speak, title_color=Fore.YELLOW)
|
||||
if isinstance(thoughts, AssistantThoughts):
|
||||
print_attribute(
|
||||
"REASONING", remove_ansi_escape(thoughts.reasoning), title_color=Fore.YELLOW
|
||||
)
|
||||
if assistant_thoughts_plan := remove_ansi_escape(
|
||||
"\n".join(f"- {p}" for p in thoughts.plan)
|
||||
):
|
||||
print_attribute("PLAN", "", title_color=Fore.YELLOW)
|
||||
# If it's a list, join it into a string
|
||||
if isinstance(assistant_thoughts_plan, list):
|
||||
assistant_thoughts_plan = "\n".join(assistant_thoughts_plan)
|
||||
elif isinstance(assistant_thoughts_plan, dict):
|
||||
assistant_thoughts_plan = str(assistant_thoughts_plan)
|
||||
|
||||
# Split the input_string using the newline character and dashes
|
||||
lines = assistant_thoughts_plan.split("\n")
|
||||
for line in lines:
|
||||
line = line.lstrip("- ")
|
||||
logger.info(
|
||||
line.strip(), extra={"title": "- ", "title_color": Fore.GREEN}
|
||||
)
|
||||
print_attribute(
|
||||
"CRITICISM",
|
||||
remove_ansi_escape(thoughts.self_criticism),
|
||||
title_color=Fore.YELLOW,
|
||||
)
|
||||
|
||||
# Speak the assistant's thoughts
|
||||
if assistant_thoughts_speak := remove_ansi_escape(thoughts.speak):
|
||||
if speak_mode:
|
||||
speak(assistant_thoughts_speak)
|
||||
else:
|
||||
print_attribute(
|
||||
"SPEAK", assistant_thoughts_speak, title_color=Fore.YELLOW
|
||||
)
|
||||
else:
|
||||
speak(thoughts_text)
|
||||
|
||||
|
||||
def remove_ansi_escape(s: str) -> str:
|
||||
|
||||
@@ -20,34 +20,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def clean_input(config: "Config", prompt: str = ""):
|
||||
try:
|
||||
if config.chat_messages_enabled:
|
||||
for plugin in config.plugins:
|
||||
if not hasattr(plugin, "can_handle_user_input"):
|
||||
continue
|
||||
if not plugin.can_handle_user_input(user_input=prompt):
|
||||
continue
|
||||
plugin_response = plugin.user_input(user_input=prompt)
|
||||
if not plugin_response:
|
||||
continue
|
||||
if plugin_response.lower() in [
|
||||
"yes",
|
||||
"yeah",
|
||||
"y",
|
||||
"ok",
|
||||
"okay",
|
||||
"sure",
|
||||
"alright",
|
||||
]:
|
||||
return config.authorise_key
|
||||
elif plugin_response.lower() in [
|
||||
"no",
|
||||
"nope",
|
||||
"n",
|
||||
"negative",
|
||||
]:
|
||||
return config.exit_key
|
||||
return plugin_response
|
||||
|
||||
# ask for input, default when just pressing Enter is y
|
||||
logger.debug("Asking user via keyboard...")
|
||||
|
||||
@@ -215,7 +187,7 @@ def print_motd(config: "Config", logger: logging.Logger):
|
||||
},
|
||||
msg=motd_line,
|
||||
)
|
||||
if is_new_motd and not config.chat_messages_enabled:
|
||||
if is_new_motd:
|
||||
input(
|
||||
Fore.MAGENTA
|
||||
+ Style.BRIGHT
|
||||
|
||||
@@ -1,12 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, ParamSpec, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.base import BaseAgent
|
||||
from autogpt.config import Config
|
||||
import re
|
||||
from typing import Callable, Optional, ParamSpec, TypeVar
|
||||
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.models.command import Command, CommandOutput, CommandParameter
|
||||
@@ -19,19 +12,35 @@ CO = TypeVar("CO", bound=CommandOutput)
|
||||
|
||||
|
||||
def command(
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: dict[str, JSONSchema],
|
||||
enabled: Literal[True] | Callable[[Config], bool] = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
aliases: list[str] = [],
|
||||
available: bool | Callable[[BaseAgent], bool] = True,
|
||||
) -> Callable[[Callable[P, CO]], Callable[P, CO]]:
|
||||
names: list[str] = [],
|
||||
description: Optional[str] = None,
|
||||
parameters: dict[str, JSONSchema] = {},
|
||||
) -> Callable[[Callable[P, CommandOutput]], Command]:
|
||||
"""
|
||||
The command decorator is used to create Command objects from ordinary functions.
|
||||
The command decorator is used to make a Command from a function.
|
||||
|
||||
Args:
|
||||
names (list[str]): The names of the command.
|
||||
If not provided, the function name will be used.
|
||||
description (str): A brief description of what the command does.
|
||||
If not provided, the docstring until double line break will be used
|
||||
(or entire docstring if no double line break is found)
|
||||
parameters (dict[str, JSONSchema]): The parameters of the function
|
||||
that the command executes.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, CO]) -> Callable[P, CO]:
|
||||
def decorator(func: Callable[P, CO]) -> Command:
|
||||
doc = func.__doc__ or ""
|
||||
# If names is not provided, use the function name
|
||||
command_names = names or [func.__name__]
|
||||
# If description is not provided, use the first part of the docstring
|
||||
if not (command_description := description):
|
||||
if not func.__doc__:
|
||||
raise ValueError("Description is required if function has no docstring")
|
||||
# Return the part of the docstring before double line break or everything
|
||||
command_description = re.sub(r"\s+", " ", doc.split("\n\n")[0].strip())
|
||||
|
||||
# Parameters
|
||||
typed_parameters = [
|
||||
CommandParameter(
|
||||
name=param_name,
|
||||
@@ -39,32 +48,15 @@ def command(
|
||||
)
|
||||
for param_name, spec in parameters.items()
|
||||
]
|
||||
cmd = Command(
|
||||
name=name,
|
||||
description=description,
|
||||
|
||||
# Wrap func with Command
|
||||
command = Command(
|
||||
names=command_names,
|
||||
description=command_description,
|
||||
method=func,
|
||||
parameters=typed_parameters,
|
||||
enabled=enabled,
|
||||
disabled_reason=disabled_reason,
|
||||
aliases=aliases,
|
||||
available=available,
|
||||
)
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
else:
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
setattr(wrapper, "command", cmd)
|
||||
setattr(wrapper, AUTO_GPT_COMMAND_IDENTIFIER, True)
|
||||
|
||||
return wrapper
|
||||
return command
|
||||
|
||||
return decorator
|
||||
|
||||
128
autogpts/autogpt/autogpt/commands/README.md
Normal file
128
autogpts/autogpt/autogpt/commands/README.md
Normal file
@@ -0,0 +1,128 @@
|
||||
# 🧩 Components
|
||||
|
||||
Components are the building blocks of [🤖 Agents](./agents.md). They are classes inheriting `AgentComponent` or implementing one or more [⚙️ Protocols](./protocols.md) that give agent additional abilities or processing.
|
||||
|
||||
Components can be used to implement various functionalities like providing messages to the prompt, executing code, or interacting with external services.
|
||||
They can be enabled or disabled, ordered, and can rely on each other.
|
||||
|
||||
Components assigned in the agent's `__init__` via `self` are automatically detected upon the agent's instantiation.
|
||||
For example inside `__init__`: `self.my_component = MyComponent()`.
|
||||
You can use any valid Python variable name, what matters for the component to be detected is its type (`AgentComponent` or any protocol inheriting from it).
|
||||
|
||||
Visit [Built-in Components](./built-in-components.md) to see what components are available out of the box.
|
||||
|
||||
```py
|
||||
from autogpt.agents import Agent
|
||||
from autogpt.agents.components import AgentComponent
|
||||
|
||||
class HelloComponent(AgentComponent):
|
||||
pass
|
||||
|
||||
class SomeComponent(AgentComponent):
|
||||
def __init__(self, hello_component: HelloComponent):
|
||||
self.hello_component = hello_component
|
||||
|
||||
class MyAgent(Agent):
|
||||
def __init__(self):
|
||||
# These components will be automatically discovered and used
|
||||
self.hello_component = HelloComponent()
|
||||
# We pass HelloComponent to SomeComponent
|
||||
self.some_component = SomeComponent(self.hello_component)
|
||||
```
|
||||
|
||||
## Ordering components
|
||||
|
||||
The execution order of components is important because the latter ones may depend on the results of the former ones.
|
||||
|
||||
### Implicit order
|
||||
|
||||
Components can be ordered implicitly by the agent; each component can set `run_after` list to specify which components should run before it. This is useful when components rely on each other or need to be executed in a specific order. Otherwise, the order of components is alphabetical.
|
||||
|
||||
```py
|
||||
# This component will run after HelloComponent
|
||||
class CalculatorComponent(AgentComponent):
|
||||
run_after = [HelloComponent]
|
||||
```
|
||||
|
||||
### Explicit order
|
||||
|
||||
Sometimes it may be easier to order components explicitly by setting `self.components` list in the agent's `__init__` method. This way you can also ensure there's no circular dependencies and `run_after` is ignored.
|
||||
|
||||
!!! warning
|
||||
Be sure to include all components - by setting `self.components` list, you're overriding the default behavior of discovering components automatically. Since it's usually not intended agent will inform you in the terminal if some components were skipped.
|
||||
|
||||
```py
|
||||
class MyAgent(Agent):
|
||||
def __init__(self):
|
||||
self.hello_component = HelloComponent()
|
||||
self.calculator_component = CalculatorComponent(self.hello_component)
|
||||
# Explicitly set components list
|
||||
self.components = [self.hello_component, self.calculator_component]
|
||||
```
|
||||
|
||||
## Disabling components
|
||||
|
||||
You can control which components are enabled by setting their `_enabled` attribute.
|
||||
Either provide a `bool` value or a `Callable[[], bool]`, will be checked each time
|
||||
the component is about to be executed. This way you can dynamically enable or disable
|
||||
components based on some conditions.
|
||||
You can also provide a reason for disabling the component by setting `_disabled_reason`.
|
||||
The reason will be visible in the debug information.
|
||||
|
||||
```py
|
||||
class DisabledComponent(MessageProvider):
|
||||
def __init__(self):
|
||||
# Disable this component
|
||||
self._enabled = False
|
||||
self._disabled_reason = "This component is disabled because of reasons."
|
||||
|
||||
# Or disable based on some condition, either statically...:
|
||||
self._enabled = self.some_property is not None
|
||||
# ... or dynamically:
|
||||
self._enabled = lambda: self.some_property is not None
|
||||
|
||||
# This method will never be called
|
||||
def get_messages(self) -> Iterator[ChatMessage]:
|
||||
yield ChatMessage.user("This message won't be seen!")
|
||||
|
||||
def some_condition(self) -> bool:
|
||||
return False
|
||||
```
|
||||
|
||||
If you don't want the component at all, you can just remove it from the agent's `__init__` method. If you want to remove components you inherit from the parent class you can set the relevant attribute to `None`:
|
||||
|
||||
!!! Warning
|
||||
Be careful when removing components that are required by other components. This may lead to errors and unexpected behavior.
|
||||
|
||||
```py
|
||||
class MyAgent(Agent):
|
||||
def __init__(self):
|
||||
super().__init__(...)
|
||||
# Disable WatchdogComponent that is in the parent class
|
||||
self.watchdog = None
|
||||
|
||||
```
|
||||
|
||||
## Exceptions
|
||||
|
||||
Custom errors are provided which can be used to control the execution flow in case something went wrong. All those errors can be raised in protocol methods and will be caught by the agent.
|
||||
By default agent will retry three times and then re-raise an exception if it's still not resolved. All passed arguments are automatically handled and the values are reverted when needed.
|
||||
All errors accept an optional `str` message. There are following errors ordered by increasing broadness:
|
||||
|
||||
1. `ComponentEndpointError`: A single endpoint method failed to execute. Agent will retry the execution of this endpoint on the component.
|
||||
2. `EndpointPipelineError`: A pipeline failed to execute. Agent will retry the execution of the endpoint for all components.
|
||||
3. `ComponentSystemError`: Multiple pipelines failed.
|
||||
|
||||
**Example**
|
||||
|
||||
```py
|
||||
from autogpt.agents.components import ComponentEndpointError
|
||||
from autogpt.agents.protocols import MessageProvider
|
||||
|
||||
# Example of raising an error
|
||||
class MyComponent(MessageProvider):
|
||||
def get_messages(self) -> Iterator[ChatMessage]:
|
||||
# This will cause the component to always fail
|
||||
# and retry 3 times before re-raising the exception
|
||||
raise ComponentEndpointError("Endpoint error!")
|
||||
```
|
||||
@@ -1,9 +0,0 @@
|
||||
COMMAND_CATEGORIES = [
|
||||
"autogpt.commands.execute_code",
|
||||
"autogpt.commands.file_operations",
|
||||
"autogpt.commands.user_interaction",
|
||||
"autogpt.commands.web_search",
|
||||
"autogpt.commands.web_selenium",
|
||||
"autogpt.commands.system",
|
||||
"autogpt.commands.image_gen",
|
||||
]
|
||||
@@ -1,82 +0,0 @@
|
||||
import functools
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Callable, ParamSpec, TypeVar
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def sanitize_path_arg(
|
||||
arg_name: str, make_relative: bool = False
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
"""Sanitizes the specified path (str | Path) argument, resolving it to a Path"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
# Get position of path parameter, in case it is passed as a positional argument
|
||||
try:
|
||||
arg_index = list(func.__annotations__.keys()).index(arg_name)
|
||||
except ValueError:
|
||||
raise TypeError(
|
||||
f"Sanitized parameter '{arg_name}' absent or not annotated"
|
||||
f" on function '{func.__name__}'"
|
||||
)
|
||||
|
||||
# Get position of agent parameter, in case it is passed as a positional argument
|
||||
try:
|
||||
agent_arg_index = list(func.__annotations__.keys()).index("agent")
|
||||
except ValueError:
|
||||
raise TypeError(
|
||||
f"Parameter 'agent' absent or not annotated"
|
||||
f" on function '{func.__name__}'"
|
||||
)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
logger.debug(f"Sanitizing arg '{arg_name}' on function '{func.__name__}'")
|
||||
|
||||
# Get Agent from the called function's arguments
|
||||
agent = kwargs.get(
|
||||
"agent", len(args) > agent_arg_index and args[agent_arg_index]
|
||||
)
|
||||
if not isinstance(agent, Agent):
|
||||
raise RuntimeError("Could not get Agent from decorated command's args")
|
||||
|
||||
# Sanitize the specified path argument, if one is given
|
||||
given_path: str | Path | None = kwargs.get(
|
||||
arg_name, len(args) > arg_index and args[arg_index] or None
|
||||
)
|
||||
if given_path:
|
||||
if type(given_path) is str:
|
||||
# Fix workspace path from output in docker environment
|
||||
given_path = re.sub(r"^\/workspace", ".", given_path)
|
||||
|
||||
if given_path in {"", "/", "."}:
|
||||
sanitized_path = agent.workspace.root
|
||||
else:
|
||||
sanitized_path = agent.workspace.get_path(given_path)
|
||||
|
||||
# Make path relative if possible
|
||||
if make_relative and sanitized_path.is_relative_to(
|
||||
agent.workspace.root
|
||||
):
|
||||
sanitized_path = sanitized_path.relative_to(agent.workspace.root)
|
||||
|
||||
if arg_name in kwargs:
|
||||
kwargs[arg_name] = sanitized_path
|
||||
else:
|
||||
# args is an immutable tuple; must be converted to a list to update
|
||||
arg_list = list(args)
|
||||
arg_list[arg_index] = sanitized_path
|
||||
args = tuple(arg_list)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -1,32 +1,28 @@
|
||||
"""Commands to execute code"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Iterator
|
||||
|
||||
import docker
|
||||
from docker.errors import DockerException, ImageNotFound, NotFound
|
||||
from docker.models.containers import Container as DockerContainer
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.utils.exceptions import (
|
||||
from autogpt.agents.base import BaseAgentSettings
|
||||
from autogpt.agents.protocols import CommandProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config import Config
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.utils.exceptions import (
|
||||
CodeExecutionError,
|
||||
CommandExecutionError,
|
||||
InvalidArgumentError,
|
||||
OperationNotAllowedError,
|
||||
)
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config import Config
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
from .decorators import sanitize_path_arg
|
||||
|
||||
COMMAND_CATEGORY = "execute_code"
|
||||
COMMAND_CATEGORY_TITLE = "Execute Code"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -57,331 +53,344 @@ def is_docker_available() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@command(
|
||||
"execute_python_code",
|
||||
"Executes the given Python code inside a single-use Docker container"
|
||||
" with access to your workspace folder",
|
||||
{
|
||||
"code": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The Python code to run",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
disabled_reason="To execute python code agent "
|
||||
"must be running in a Docker container or "
|
||||
"Docker must be available on the system.",
|
||||
available=we_are_running_in_a_docker_container() or is_docker_available(),
|
||||
)
|
||||
def execute_python_code(code: str, agent: Agent) -> str:
|
||||
"""
|
||||
Create and execute a Python file in a Docker container and return the STDOUT of the
|
||||
executed code.
|
||||
class CodeExecutorComponent(CommandProvider):
|
||||
"""Provides commands to execute Python code and shell commands."""
|
||||
|
||||
If the code generates any data that needs to be captured, use a print statement.
|
||||
def __init__(
|
||||
self, workspace: FileStorage, state: BaseAgentSettings, config: Config
|
||||
):
|
||||
self.workspace = workspace
|
||||
self.state = state
|
||||
self.legacy_config = config
|
||||
|
||||
Args:
|
||||
code (str): The Python code to run.
|
||||
agent (Agent): The Agent executing the command.
|
||||
if not we_are_running_in_a_docker_container() and not is_docker_available():
|
||||
logger.info(
|
||||
"Docker is not available or does not support Linux containers. "
|
||||
"The code execution commands will not be available."
|
||||
)
|
||||
|
||||
Returns:
|
||||
str: The STDOUT captured from the code when it ran.
|
||||
"""
|
||||
if not self.legacy_config.execute_local_commands:
|
||||
logger.info(
|
||||
"Local shell commands are disabled. To enable them,"
|
||||
" set EXECUTE_LOCAL_COMMANDS to 'True' in your config file."
|
||||
)
|
||||
|
||||
tmp_code_file = NamedTemporaryFile(
|
||||
"w", dir=agent.workspace.root, suffix=".py", encoding="utf-8"
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
if we_are_running_in_a_docker_container() or is_docker_available():
|
||||
yield self.execute_python_code
|
||||
yield self.execute_python_file
|
||||
|
||||
if self.legacy_config.execute_local_commands:
|
||||
yield self.execute_shell
|
||||
yield self.execute_shell_popen
|
||||
|
||||
@command(
|
||||
["execute_python_code"],
|
||||
"Executes the given Python code inside a single-use Docker container"
|
||||
" with access to your workspace folder",
|
||||
{
|
||||
"code": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The Python code to run",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
tmp_code_file.write(code)
|
||||
tmp_code_file.flush()
|
||||
def execute_python_code(self, code: str) -> str:
|
||||
"""
|
||||
Create and execute a Python file in a Docker container
|
||||
and return the STDOUT of the executed code.
|
||||
|
||||
try:
|
||||
return execute_python_file(tmp_code_file.name, agent) # type: ignore
|
||||
except Exception as e:
|
||||
raise CommandExecutionError(*e.args)
|
||||
finally:
|
||||
tmp_code_file.close()
|
||||
If the code generates any data that needs to be captured,
|
||||
use a print statement.
|
||||
|
||||
Args:
|
||||
code (str): The Python code to run.
|
||||
agent (Agent): The Agent executing the command.
|
||||
|
||||
@command(
|
||||
"execute_python_file",
|
||||
"Execute an existing Python file inside a single-use Docker container"
|
||||
" with access to your workspace folder",
|
||||
{
|
||||
"filename": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The name of the file to execute",
|
||||
required=True,
|
||||
),
|
||||
"args": JSONSchema(
|
||||
type=JSONSchema.Type.ARRAY,
|
||||
description="The (command line) arguments to pass to the script",
|
||||
required=False,
|
||||
items=JSONSchema(type=JSONSchema.Type.STRING),
|
||||
),
|
||||
},
|
||||
disabled_reason="To execute python code agent "
|
||||
"must be running in a Docker container or "
|
||||
"Docker must be available on the system.",
|
||||
available=we_are_running_in_a_docker_container() or is_docker_available(),
|
||||
)
|
||||
@sanitize_path_arg("filename")
|
||||
def execute_python_file(
|
||||
filename: Path, agent: Agent, args: list[str] | str = []
|
||||
) -> str:
|
||||
"""Execute a Python file in a Docker container and return the output
|
||||
Returns:
|
||||
str: The STDOUT captured from the code when it ran.
|
||||
"""
|
||||
|
||||
Args:
|
||||
filename (Path): The name of the file to execute
|
||||
args (list, optional): The arguments with which to run the python script
|
||||
|
||||
Returns:
|
||||
str: The output of the file
|
||||
"""
|
||||
logger.info(
|
||||
f"Executing python file '{filename}' "
|
||||
f"in working directory '{agent.workspace.root}'"
|
||||
)
|
||||
|
||||
if isinstance(args, str):
|
||||
args = args.split() # Convert space-separated string to a list
|
||||
|
||||
if not str(filename).endswith(".py"):
|
||||
raise InvalidArgumentError("Invalid file type. Only .py files are allowed.")
|
||||
|
||||
file_path = filename
|
||||
if not file_path.is_file():
|
||||
# Mimic the response that you get from the command line to make it
|
||||
# intuitively understandable for the LLM
|
||||
raise FileNotFoundError(
|
||||
f"python: can't open file '{filename}': [Errno 2] No such file or directory"
|
||||
tmp_code_file = NamedTemporaryFile(
|
||||
"w", dir=self.workspace.root, suffix=".py", encoding="utf-8"
|
||||
)
|
||||
tmp_code_file.write(code)
|
||||
tmp_code_file.flush()
|
||||
|
||||
if we_are_running_in_a_docker_container():
|
||||
logger.debug(
|
||||
"AutoGPT is running in a Docker container; "
|
||||
f"executing {file_path} directly..."
|
||||
)
|
||||
result = subprocess.run(
|
||||
["python", "-B", str(file_path)] + args,
|
||||
capture_output=True,
|
||||
encoding="utf8",
|
||||
cwd=str(agent.workspace.root),
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout
|
||||
else:
|
||||
raise CodeExecutionError(result.stderr)
|
||||
|
||||
logger.debug("AutoGPT is not running in a Docker container")
|
||||
try:
|
||||
assert agent.state.agent_id, "Need Agent ID to attach Docker container"
|
||||
|
||||
client = docker.from_env()
|
||||
# You can replace this with the desired Python image/version
|
||||
# You can find available Python images on Docker Hub:
|
||||
# https://hub.docker.com/_/python
|
||||
image_name = "python:3-alpine"
|
||||
container_is_fresh = False
|
||||
container_name = f"{agent.state.agent_id}_sandbox"
|
||||
try:
|
||||
container: DockerContainer = client.containers.get(
|
||||
container_name
|
||||
) # type: ignore
|
||||
except NotFound:
|
||||
try:
|
||||
client.images.get(image_name)
|
||||
logger.debug(f"Image '{image_name}' found locally")
|
||||
except ImageNotFound:
|
||||
logger.info(
|
||||
f"Image '{image_name}' not found locally,"
|
||||
" pulling from Docker Hub..."
|
||||
)
|
||||
# Use the low-level API to stream the pull response
|
||||
low_level_client = docker.APIClient()
|
||||
for line in low_level_client.pull(image_name, stream=True, decode=True):
|
||||
# Print the status and progress, if available
|
||||
status = line.get("status")
|
||||
progress = line.get("progress")
|
||||
if status and progress:
|
||||
logger.info(f"{status}: {progress}")
|
||||
elif status:
|
||||
logger.info(status)
|
||||
return self.execute_python_file(tmp_code_file.name)
|
||||
except Exception as e:
|
||||
raise CommandExecutionError(*e.args)
|
||||
finally:
|
||||
tmp_code_file.close()
|
||||
|
||||
logger.debug(f"Creating new {image_name} container...")
|
||||
container: DockerContainer = client.containers.run(
|
||||
image_name,
|
||||
["sleep", "60"], # Max 60 seconds to prevent permanent hangs
|
||||
volumes={
|
||||
str(agent.workspace.root): {
|
||||
"bind": "/workspace",
|
||||
"mode": "rw",
|
||||
}
|
||||
},
|
||||
working_dir="/workspace",
|
||||
@command(
|
||||
["execute_python_file"],
|
||||
"Execute an existing Python file inside a single-use Docker container"
|
||||
" with access to your workspace folder",
|
||||
{
|
||||
"filename": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The name of the file to execute",
|
||||
required=True,
|
||||
),
|
||||
"args": JSONSchema(
|
||||
type=JSONSchema.Type.ARRAY,
|
||||
description="The (command line) arguments to pass to the script",
|
||||
required=False,
|
||||
items=JSONSchema(type=JSONSchema.Type.STRING),
|
||||
),
|
||||
},
|
||||
)
|
||||
def execute_python_file(self, filename: str, args: list[str] | str = []) -> str:
|
||||
"""Execute a Python file in a Docker container and return the output
|
||||
|
||||
Args:
|
||||
filename (Path): The name of the file to execute
|
||||
args (list, optional): The arguments with which to run the python script
|
||||
|
||||
Returns:
|
||||
str: The output of the file
|
||||
"""
|
||||
logger.info(
|
||||
f"Executing python file '{filename}' "
|
||||
f"in working directory '{self.workspace.root}'"
|
||||
)
|
||||
|
||||
if isinstance(args, str):
|
||||
args = args.split() # Convert space-separated string to a list
|
||||
|
||||
if not str(filename).endswith(".py"):
|
||||
raise InvalidArgumentError("Invalid file type. Only .py files are allowed.")
|
||||
|
||||
file_path = self.workspace.get_path(filename)
|
||||
if not self.workspace.exists(file_path):
|
||||
# Mimic the response that you get from the command line to make it
|
||||
# intuitively understandable for the LLM
|
||||
raise FileNotFoundError(
|
||||
f"python: can't open file '{filename}': "
|
||||
f"[Errno 2] No such file or directory"
|
||||
)
|
||||
|
||||
if we_are_running_in_a_docker_container():
|
||||
logger.debug(
|
||||
"AutoGPT is running in a Docker container; "
|
||||
f"executing {file_path} directly..."
|
||||
)
|
||||
result = subprocess.run(
|
||||
["python", "-B", str(file_path)] + args,
|
||||
capture_output=True,
|
||||
encoding="utf8",
|
||||
cwd=str(self.workspace.root),
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout
|
||||
else:
|
||||
raise CodeExecutionError(result.stderr)
|
||||
|
||||
logger.debug("AutoGPT is not running in a Docker container")
|
||||
try:
|
||||
assert self.state.agent_id, "Need Agent ID to attach Docker container"
|
||||
|
||||
client = docker.from_env()
|
||||
image_name = "python:3-alpine"
|
||||
container_is_fresh = False
|
||||
container_name = f"{self.state.agent_id}_sandbox"
|
||||
try:
|
||||
container: DockerContainer = client.containers.get(
|
||||
container_name
|
||||
) # type: ignore
|
||||
except NotFound:
|
||||
try:
|
||||
client.images.get(image_name)
|
||||
logger.debug(f"Image '{image_name}' found locally")
|
||||
except ImageNotFound:
|
||||
logger.info(
|
||||
f"Image '{image_name}' not found locally,"
|
||||
" pulling from Docker Hub..."
|
||||
)
|
||||
# Use the low-level API to stream the pull response
|
||||
low_level_client = docker.APIClient()
|
||||
for line in low_level_client.pull(
|
||||
image_name, stream=True, decode=True
|
||||
):
|
||||
# Print the status and progress, if available
|
||||
status = line.get("status")
|
||||
progress = line.get("progress")
|
||||
if status and progress:
|
||||
logger.info(f"{status}: {progress}")
|
||||
elif status:
|
||||
logger.info(status)
|
||||
|
||||
logger.debug(f"Creating new {image_name} container...")
|
||||
container: DockerContainer = client.containers.run(
|
||||
image_name,
|
||||
["sleep", "60"], # Max 60 seconds to prevent permanent hangs
|
||||
volumes={
|
||||
str(self.workspace.root): {
|
||||
"bind": "/workspace",
|
||||
"mode": "rw",
|
||||
}
|
||||
},
|
||||
working_dir="/workspace",
|
||||
stderr=True,
|
||||
stdout=True,
|
||||
detach=True,
|
||||
name=container_name,
|
||||
) # type: ignore
|
||||
container_is_fresh = True
|
||||
|
||||
if not container.status == "running":
|
||||
container.start()
|
||||
elif not container_is_fresh:
|
||||
container.restart()
|
||||
|
||||
logger.debug(f"Running {file_path} in container {container.name}...")
|
||||
exec_result = container.exec_run(
|
||||
[
|
||||
"python",
|
||||
"-B",
|
||||
file_path.relative_to(self.workspace.root).as_posix(),
|
||||
]
|
||||
+ args,
|
||||
stderr=True,
|
||||
stdout=True,
|
||||
detach=True,
|
||||
name=container_name,
|
||||
) # type: ignore
|
||||
container_is_fresh = True
|
||||
)
|
||||
|
||||
if not container.status == "running":
|
||||
container.start()
|
||||
elif not container_is_fresh:
|
||||
container.restart()
|
||||
if exec_result.exit_code != 0:
|
||||
raise CodeExecutionError(exec_result.output.decode("utf-8"))
|
||||
|
||||
logger.debug(f"Running {file_path} in container {container.name}...")
|
||||
exec_result = container.exec_run(
|
||||
[
|
||||
"python",
|
||||
"-B",
|
||||
file_path.relative_to(agent.workspace.root).as_posix(),
|
||||
]
|
||||
+ args,
|
||||
stderr=True,
|
||||
stdout=True,
|
||||
return exec_result.output.decode("utf-8")
|
||||
|
||||
except DockerException as e:
|
||||
logger.warning(
|
||||
"Could not run the script in a container. "
|
||||
"If you haven't already, please install Docker: "
|
||||
"https://docs.docker.com/get-docker/"
|
||||
)
|
||||
raise CommandExecutionError(f"Could not run the script in a container: {e}")
|
||||
|
||||
def validate_command(self, command_line: str, config: Config) -> tuple[bool, bool]:
|
||||
"""Check whether a command is allowed and whether it may be executed in a shell.
|
||||
|
||||
If shell command control is enabled, we disallow executing in a shell, because
|
||||
otherwise the model could circumvent the command filter using shell features.
|
||||
|
||||
Args:
|
||||
command_line (str): The command line to validate
|
||||
config (Config): The app config including shell command control settings
|
||||
|
||||
Returns:
|
||||
bool: True if the command is allowed, False otherwise
|
||||
bool: True if the command may be executed in a shell, False otherwise
|
||||
"""
|
||||
if not command_line:
|
||||
return False, False
|
||||
|
||||
command_name = shlex.split(command_line)[0]
|
||||
|
||||
if config.shell_command_control == ALLOWLIST_CONTROL:
|
||||
return command_name in config.shell_allowlist, False
|
||||
elif config.shell_command_control == DENYLIST_CONTROL:
|
||||
return command_name not in config.shell_denylist, False
|
||||
else:
|
||||
return True, True
|
||||
|
||||
@command(
|
||||
["execute_shell"],
|
||||
"Execute a Shell Command, non-interactive commands only",
|
||||
{
|
||||
"command_line": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The command line to execute",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def execute_shell(self, command_line: str) -> str:
|
||||
"""Execute a shell command and return the output
|
||||
|
||||
Args:
|
||||
command_line (str): The command line to execute
|
||||
|
||||
Returns:
|
||||
str: The output of the command
|
||||
"""
|
||||
allow_execute, allow_shell = self.validate_command(
|
||||
command_line, self.legacy_config
|
||||
)
|
||||
if not allow_execute:
|
||||
logger.info(f"Command '{command_line}' not allowed")
|
||||
raise OperationNotAllowedError("This shell command is not allowed.")
|
||||
|
||||
current_dir = Path.cwd()
|
||||
# Change dir into workspace if necessary
|
||||
if not current_dir.is_relative_to(self.workspace.root):
|
||||
os.chdir(self.workspace.root)
|
||||
|
||||
logger.info(
|
||||
f"Executing command '{command_line}' in working directory '{os.getcwd()}'"
|
||||
)
|
||||
|
||||
if exec_result.exit_code != 0:
|
||||
raise CodeExecutionError(exec_result.output.decode("utf-8"))
|
||||
|
||||
return exec_result.output.decode("utf-8")
|
||||
|
||||
except DockerException as e:
|
||||
logger.warning(
|
||||
"Could not run the script in a container. "
|
||||
"If you haven't already, please install Docker: "
|
||||
"https://docs.docker.com/get-docker/"
|
||||
result = subprocess.run(
|
||||
command_line if allow_shell else shlex.split(command_line),
|
||||
capture_output=True,
|
||||
shell=allow_shell,
|
||||
)
|
||||
raise CommandExecutionError(f"Could not run the script in a container: {e}")
|
||||
output = f"STDOUT:\n{result.stdout.decode()}\nSTDERR:\n{result.stderr.decode()}"
|
||||
|
||||
# Change back to whatever the prior working dir was
|
||||
os.chdir(current_dir)
|
||||
|
||||
def validate_command(command_line: str, config: Config) -> tuple[bool, bool]:
|
||||
"""Check whether a command is allowed and whether it may be executed in a shell.
|
||||
return output
|
||||
|
||||
If shell command control is enabled, we disallow executing in a shell, because
|
||||
otherwise the model could easily circumvent the command filter using shell features.
|
||||
@command(
|
||||
["execute_shell_popen"],
|
||||
"Execute a Shell Command, non-interactive commands only",
|
||||
{
|
||||
"command_line": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The command line to execute",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def execute_shell_popen(self, command_line: str) -> str:
|
||||
"""Execute a shell command with Popen and returns an english description
|
||||
of the event and the process id
|
||||
|
||||
Args:
|
||||
command_line (str): The command line to validate
|
||||
config (Config): The application config including shell command control settings
|
||||
Args:
|
||||
command_line (str): The command line to execute
|
||||
|
||||
Returns:
|
||||
bool: True if the command is allowed, False otherwise
|
||||
bool: True if the command may be executed in a shell, False otherwise
|
||||
"""
|
||||
if not command_line:
|
||||
return False, False
|
||||
|
||||
command_name = shlex.split(command_line)[0]
|
||||
|
||||
if config.shell_command_control == ALLOWLIST_CONTROL:
|
||||
return command_name in config.shell_allowlist, False
|
||||
elif config.shell_command_control == DENYLIST_CONTROL:
|
||||
return command_name not in config.shell_denylist, False
|
||||
else:
|
||||
return True, True
|
||||
|
||||
|
||||
@command(
|
||||
"execute_shell",
|
||||
"Execute a Shell Command, non-interactive commands only",
|
||||
{
|
||||
"command_line": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The command line to execute",
|
||||
required=True,
|
||||
Returns:
|
||||
str: Description of the fact that the process started and its id
|
||||
"""
|
||||
allow_execute, allow_shell = self.validate_command(
|
||||
command_line, self.legacy_config
|
||||
)
|
||||
},
|
||||
enabled=lambda config: config.execute_local_commands,
|
||||
disabled_reason="You are not allowed to run local shell commands. To execute"
|
||||
" shell commands, EXECUTE_LOCAL_COMMANDS must be set to 'True' "
|
||||
"in your config file: .env - do not attempt to bypass the restriction.",
|
||||
)
|
||||
def execute_shell(command_line: str, agent: Agent) -> str:
|
||||
"""Execute a shell command and return the output
|
||||
if not allow_execute:
|
||||
logger.info(f"Command '{command_line}' not allowed")
|
||||
raise OperationNotAllowedError("This shell command is not allowed.")
|
||||
|
||||
Args:
|
||||
command_line (str): The command line to execute
|
||||
current_dir = Path.cwd()
|
||||
# Change dir into workspace if necessary
|
||||
if not current_dir.is_relative_to(self.workspace.root):
|
||||
os.chdir(self.workspace.root)
|
||||
|
||||
Returns:
|
||||
str: The output of the command
|
||||
"""
|
||||
allow_execute, allow_shell = validate_command(command_line, agent.legacy_config)
|
||||
if not allow_execute:
|
||||
logger.info(f"Command '{command_line}' not allowed")
|
||||
raise OperationNotAllowedError("This shell command is not allowed.")
|
||||
|
||||
current_dir = Path.cwd()
|
||||
# Change dir into workspace if necessary
|
||||
if not current_dir.is_relative_to(agent.workspace.root):
|
||||
os.chdir(agent.workspace.root)
|
||||
|
||||
logger.info(
|
||||
f"Executing command '{command_line}' in working directory '{os.getcwd()}'"
|
||||
)
|
||||
|
||||
result = subprocess.run(
|
||||
command_line if allow_shell else shlex.split(command_line),
|
||||
capture_output=True,
|
||||
shell=allow_shell,
|
||||
)
|
||||
output = f"STDOUT:\n{result.stdout.decode()}\nSTDERR:\n{result.stderr.decode()}"
|
||||
|
||||
# Change back to whatever the prior working dir was
|
||||
os.chdir(current_dir)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@command(
|
||||
"execute_shell_popen",
|
||||
"Execute a Shell Command, non-interactive commands only",
|
||||
{
|
||||
"command_line": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The command line to execute",
|
||||
required=True,
|
||||
logger.info(
|
||||
f"Executing command '{command_line}' in working directory '{os.getcwd()}'"
|
||||
)
|
||||
},
|
||||
lambda config: config.execute_local_commands,
|
||||
"You are not allowed to run local shell commands. To execute"
|
||||
" shell commands, EXECUTE_LOCAL_COMMANDS must be set to 'True' "
|
||||
"in your config. Do not attempt to bypass the restriction.",
|
||||
)
|
||||
def execute_shell_popen(command_line: str, agent: Agent) -> str:
|
||||
"""Execute a shell command with Popen and returns an english description
|
||||
of the event and the process id
|
||||
|
||||
Args:
|
||||
command_line (str): The command line to execute
|
||||
do_not_show_output = subprocess.DEVNULL
|
||||
process = subprocess.Popen(
|
||||
command_line if allow_shell else shlex.split(command_line),
|
||||
shell=allow_shell,
|
||||
stdout=do_not_show_output,
|
||||
stderr=do_not_show_output,
|
||||
)
|
||||
|
||||
Returns:
|
||||
str: Description of the fact that the process started and its id
|
||||
"""
|
||||
allow_execute, allow_shell = validate_command(command_line, agent.legacy_config)
|
||||
if not allow_execute:
|
||||
logger.info(f"Command '{command_line}' not allowed")
|
||||
raise OperationNotAllowedError("This shell command is not allowed.")
|
||||
# Change back to whatever the prior working dir was
|
||||
os.chdir(current_dir)
|
||||
|
||||
current_dir = Path.cwd()
|
||||
# Change dir into workspace if necessary
|
||||
if not current_dir.is_relative_to(agent.workspace.root):
|
||||
os.chdir(agent.workspace.root)
|
||||
|
||||
logger.info(
|
||||
f"Executing command '{command_line}' in working directory '{os.getcwd()}'"
|
||||
)
|
||||
|
||||
do_not_show_output = subprocess.DEVNULL
|
||||
process = subprocess.Popen(
|
||||
command_line if allow_shell else shlex.split(command_line),
|
||||
shell=allow_shell,
|
||||
stdout=do_not_show_output,
|
||||
stderr=do_not_show_output,
|
||||
)
|
||||
|
||||
# Change back to whatever the prior working dir was
|
||||
os.chdir(current_dir)
|
||||
|
||||
return f"Subprocess started with PID:'{str(process.pid)}'"
|
||||
return f"Subprocess started with PID:'{str(process.pid)}'"
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
"""Commands to perform operations on files"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from autogpt.agents.features.context import ContextMixin, get_agent_context
|
||||
from autogpt.agents.utils.exceptions import (
|
||||
CommandExecutionError,
|
||||
DuplicateOperationError,
|
||||
)
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.models.context_item import FileContextItem, FolderContextItem
|
||||
|
||||
from .decorators import sanitize_path_arg
|
||||
|
||||
COMMAND_CATEGORY = "file_operations"
|
||||
COMMAND_CATEGORY_TITLE = "File Operations"
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents import Agent, BaseAgent
|
||||
|
||||
|
||||
def agent_implements_context(agent: BaseAgent) -> bool:
|
||||
return isinstance(agent, ContextMixin)
|
||||
|
||||
|
||||
@command(
|
||||
"open_file",
|
||||
"Opens a file for editing or continued viewing;"
|
||||
" creates it if it does not exist yet. "
|
||||
"Note: If you only need to read or write a file once, use `write_to_file` instead.",
|
||||
{
|
||||
"file_path": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path of the file to open",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
available=agent_implements_context,
|
||||
)
|
||||
@sanitize_path_arg("file_path")
|
||||
def open_file(file_path: Path, agent: Agent) -> tuple[str, FileContextItem]:
|
||||
"""Open a file and return a context item
|
||||
|
||||
Args:
|
||||
file_path (Path): The path of the file to open
|
||||
|
||||
Returns:
|
||||
str: A status message indicating what happened
|
||||
FileContextItem: A ContextItem representing the opened file
|
||||
"""
|
||||
# Try to make the file path relative
|
||||
relative_file_path = None
|
||||
with contextlib.suppress(ValueError):
|
||||
relative_file_path = file_path.relative_to(agent.workspace.root)
|
||||
|
||||
assert (agent_context := get_agent_context(agent)) is not None
|
||||
|
||||
created = False
|
||||
if not file_path.exists():
|
||||
file_path.touch()
|
||||
created = True
|
||||
elif not file_path.is_file():
|
||||
raise CommandExecutionError(f"{file_path} exists but is not a file")
|
||||
|
||||
file_path = relative_file_path or file_path
|
||||
|
||||
file = FileContextItem(
|
||||
file_path_in_workspace=file_path,
|
||||
workspace_path=agent.workspace.root,
|
||||
)
|
||||
if file in agent_context:
|
||||
raise DuplicateOperationError(f"The file {file_path} is already open")
|
||||
|
||||
return (
|
||||
f"File {file_path}{' created,' if created else ''} has been opened"
|
||||
" and added to the context ✅",
|
||||
file,
|
||||
)
|
||||
|
||||
|
||||
@command(
|
||||
"open_folder",
|
||||
"Open a folder to keep track of its content",
|
||||
{
|
||||
"path": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path of the folder to open",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
available=agent_implements_context,
|
||||
)
|
||||
@sanitize_path_arg("path")
|
||||
def open_folder(path: Path, agent: Agent) -> tuple[str, FolderContextItem]:
|
||||
"""Open a folder and return a context item
|
||||
|
||||
Args:
|
||||
path (Path): The path of the folder to open
|
||||
|
||||
Returns:
|
||||
str: A status message indicating what happened
|
||||
FolderContextItem: A ContextItem representing the opened folder
|
||||
"""
|
||||
# Try to make the path relative
|
||||
relative_path = None
|
||||
with contextlib.suppress(ValueError):
|
||||
relative_path = path.relative_to(agent.workspace.root)
|
||||
|
||||
assert (agent_context := get_agent_context(agent)) is not None
|
||||
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"open_folder {path} failed: no such file or directory")
|
||||
elif not path.is_dir():
|
||||
raise CommandExecutionError(f"{path} exists but is not a folder")
|
||||
|
||||
path = relative_path or path
|
||||
|
||||
folder = FolderContextItem(
|
||||
path_in_workspace=path,
|
||||
workspace_path=agent.workspace.root,
|
||||
)
|
||||
if folder in agent_context:
|
||||
raise DuplicateOperationError(f"The folder {path} is already open")
|
||||
|
||||
return f"Folder {path} has been opened and added to the context ✅", folder
|
||||
@@ -1,241 +0,0 @@
|
||||
"""Commands to perform operations on files"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import os.path
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Literal
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.utils.exceptions import DuplicateOperationError
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.memory.vector import MemoryItemFactory, VectorMemory
|
||||
|
||||
from .decorators import sanitize_path_arg
|
||||
from .file_operations_utils import decode_textual_file
|
||||
|
||||
COMMAND_CATEGORY = "file_operations"
|
||||
COMMAND_CATEGORY_TITLE = "File Operations"
|
||||
|
||||
|
||||
from .file_context import open_file, open_folder # NOQA
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Operation = Literal["write", "append", "delete"]
|
||||
|
||||
|
||||
def text_checksum(text: str) -> str:
|
||||
"""Get the hex checksum for the given text."""
|
||||
return hashlib.md5(text.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def operations_from_log(
|
||||
logs: list[str],
|
||||
) -> Iterator[
|
||||
tuple[Literal["write", "append"], str, str] | tuple[Literal["delete"], str, None]
|
||||
]:
|
||||
"""Parse logs and return a tuple containing the log entries"""
|
||||
for line in logs:
|
||||
line = line.replace("File Operation Logger", "").strip()
|
||||
if not line:
|
||||
continue
|
||||
operation, tail = line.split(": ", maxsplit=1)
|
||||
operation = operation.strip()
|
||||
if operation in ("write", "append"):
|
||||
path, checksum = (x.strip() for x in tail.rsplit(" #", maxsplit=1))
|
||||
yield (operation, path, checksum)
|
||||
elif operation == "delete":
|
||||
yield (operation, tail.strip(), None)
|
||||
|
||||
|
||||
def file_operations_state(logs: list[str]) -> dict[str, str]:
|
||||
"""Iterates over the operations and returns the expected state.
|
||||
|
||||
Constructs a dictionary that maps each file path written
|
||||
or appended to its checksum. Deleted files are
|
||||
removed from the dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping file paths to their checksums.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file_manager.file_ops_log_path is not found.
|
||||
ValueError: If the log file content is not in the expected format.
|
||||
"""
|
||||
state = {}
|
||||
for operation, path, checksum in operations_from_log(logs):
|
||||
if operation in ("write", "append"):
|
||||
state[path] = checksum
|
||||
elif operation == "delete":
|
||||
del state[path]
|
||||
return state
|
||||
|
||||
|
||||
@sanitize_path_arg("file_path", make_relative=True)
|
||||
def is_duplicate_operation(
|
||||
operation: Operation, file_path: Path, agent: Agent, checksum: str | None = None
|
||||
) -> bool:
|
||||
"""Check if the operation has already been performed
|
||||
|
||||
Args:
|
||||
operation: The operation to check for
|
||||
file_path: The name of the file to check for
|
||||
agent: The agent
|
||||
checksum: The checksum of the contents to be written
|
||||
|
||||
Returns:
|
||||
True if the operation has already been performed on the file
|
||||
"""
|
||||
state = file_operations_state(agent.get_file_operation_lines())
|
||||
if operation == "delete" and file_path.as_posix() not in state:
|
||||
return True
|
||||
if operation == "write" and state.get(file_path.as_posix()) == checksum:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@sanitize_path_arg("file_path", make_relative=True)
|
||||
async def log_operation(
|
||||
operation: Operation,
|
||||
file_path: str | Path,
|
||||
agent: Agent,
|
||||
checksum: str | None = None,
|
||||
) -> None:
|
||||
"""Log the file operation to the file_logger.log
|
||||
|
||||
Args:
|
||||
operation: The operation to log
|
||||
file_path: The name of the file the operation was performed on
|
||||
checksum: The checksum of the contents to be written
|
||||
"""
|
||||
log_entry = (
|
||||
f"{operation}: "
|
||||
f"{file_path.as_posix() if isinstance(file_path, Path) else file_path}"
|
||||
)
|
||||
if checksum is not None:
|
||||
log_entry += f" #{checksum}"
|
||||
logger.debug(f"Logging file operation: {log_entry}")
|
||||
await agent.log_file_operation(log_entry)
|
||||
|
||||
|
||||
@command(
|
||||
"read_file",
|
||||
"Read an existing file",
|
||||
{
|
||||
"filename": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path of the file to read",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def read_file(filename: str | Path, agent: Agent) -> str:
|
||||
"""Read a file and return the contents
|
||||
|
||||
Args:
|
||||
filename (Path): The name of the file to read
|
||||
|
||||
Returns:
|
||||
str: The contents of the file
|
||||
"""
|
||||
file = agent.workspace.open_file(filename, binary=True)
|
||||
content = decode_textual_file(file, os.path.splitext(filename)[1], logger)
|
||||
|
||||
# # TODO: invalidate/update memory when file is edited
|
||||
# file_memory = MemoryItem.from_text_file(content, str(filename), agent.config)
|
||||
# if len(file_memory.chunks) > 1:
|
||||
# return file_memory.summary
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def ingest_file(
|
||||
filename: str,
|
||||
memory: VectorMemory,
|
||||
) -> None:
|
||||
"""
|
||||
Ingest a file by reading its content, splitting it into chunks with a specified
|
||||
maximum length and overlap, and adding the chunks to the memory storage.
|
||||
|
||||
Args:
|
||||
filename: The name of the file to ingest
|
||||
memory: An object with an add() method to store the chunks in memory
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Ingesting file {filename}")
|
||||
content = read_file(filename)
|
||||
|
||||
# TODO: differentiate between different types of files
|
||||
file_memory = MemoryItemFactory.from_text_file(content, filename)
|
||||
logger.debug(f"Created memory: {file_memory.dump(True)}")
|
||||
memory.add(file_memory)
|
||||
|
||||
logger.info(f"Ingested {len(file_memory.e_chunks)} chunks from {filename}")
|
||||
except Exception as err:
|
||||
logger.warning(f"Error while ingesting file '{filename}': {err}")
|
||||
|
||||
|
||||
@command(
|
||||
"write_file",
|
||||
"Write a file, creating it if necessary. If the file exists, it is overwritten.",
|
||||
{
|
||||
"filename": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The name of the file to write to",
|
||||
required=True,
|
||||
),
|
||||
"contents": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The contents to write to the file",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
aliases=["create_file"],
|
||||
)
|
||||
async def write_to_file(filename: str | Path, contents: str, agent: Agent) -> str:
|
||||
"""Write contents to a file
|
||||
|
||||
Args:
|
||||
filename (Path): The name of the file to write to
|
||||
contents (str): The contents to write to the file
|
||||
|
||||
Returns:
|
||||
str: A message indicating success or failure
|
||||
"""
|
||||
checksum = text_checksum(contents)
|
||||
if is_duplicate_operation("write", Path(filename), agent, checksum):
|
||||
raise DuplicateOperationError(f"File {filename} has already been updated.")
|
||||
|
||||
if directory := os.path.dirname(filename):
|
||||
agent.workspace.make_dir(directory)
|
||||
await agent.workspace.write_file(filename, contents)
|
||||
await log_operation("write", filename, agent, checksum)
|
||||
return f"File {filename} has been written successfully."
|
||||
|
||||
|
||||
@command(
|
||||
"list_folder",
|
||||
"List the items in a folder",
|
||||
{
|
||||
"folder": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The folder to list files in",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def list_folder(folder: str | Path, agent: Agent) -> list[str]:
|
||||
"""Lists files in a folder recursively
|
||||
|
||||
Args:
|
||||
folder (Path): The folder to search in
|
||||
|
||||
Returns:
|
||||
list[str]: A list of files found in the folder
|
||||
"""
|
||||
return [str(p) for p in agent.workspace.list_files(folder)]
|
||||
@@ -1,58 +1,61 @@
|
||||
"""Commands to perform Git operations"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from git.repo import Repo
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.utils.exceptions import CommandExecutionError
|
||||
from autogpt.agents.protocols import CommandProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.url_utils.validators import validate_url
|
||||
|
||||
from .decorators import sanitize_path_arg
|
||||
|
||||
COMMAND_CATEGORY = "git_operations"
|
||||
COMMAND_CATEGORY_TITLE = "Git Operations"
|
||||
from autogpt.utils.exceptions import CommandExecutionError
|
||||
|
||||
|
||||
@command(
|
||||
"clone_repository",
|
||||
"Clones a Repository",
|
||||
{
|
||||
"url": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The URL of the repository to clone",
|
||||
required=True,
|
||||
),
|
||||
"clone_path": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path to clone the repository to",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
lambda config: bool(config.github_username and config.github_api_key),
|
||||
"Configure github_username and github_api_key.",
|
||||
)
|
||||
@sanitize_path_arg("clone_path")
|
||||
@validate_url
|
||||
def clone_repository(url: str, clone_path: Path, agent: Agent) -> str:
|
||||
"""Clone a GitHub repository locally.
|
||||
class GitOperationsComponent(CommandProvider):
|
||||
"""Provides commands to perform Git operations."""
|
||||
|
||||
Args:
|
||||
url (str): The URL of the repository to clone.
|
||||
clone_path (Path): The path to clone the repository to.
|
||||
def __init__(self, config: Config):
|
||||
self._enabled = bool(config.github_username and config.github_api_key)
|
||||
self._disabled_reason = "Configure github_username and github_api_key."
|
||||
self.legacy_config = config
|
||||
|
||||
Returns:
|
||||
str: The result of the clone operation.
|
||||
"""
|
||||
split_url = url.split("//")
|
||||
auth_repo_url = f"//{agent.legacy_config.github_username}:{agent.legacy_config.github_api_key}@".join( # noqa: E501
|
||||
split_url
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.clone_repository
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
"url": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The URL of the repository to clone",
|
||||
required=True,
|
||||
),
|
||||
"clone_path": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path to clone the repository to",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
try:
|
||||
Repo.clone_from(url=auth_repo_url, to_path=clone_path)
|
||||
except Exception as e:
|
||||
raise CommandExecutionError(f"Could not clone repo: {e}")
|
||||
@validate_url
|
||||
def clone_repository(self, url: str, clone_path: Path) -> str:
|
||||
"""Clone a GitHub repository locally.
|
||||
|
||||
return f"""Cloned {url} to {clone_path}"""
|
||||
Args:
|
||||
url (str): The URL of the repository to clone.
|
||||
clone_path (Path): The path to clone the repository to.
|
||||
|
||||
Returns:
|
||||
str: The result of the clone operation.
|
||||
"""
|
||||
split_url = url.split("//")
|
||||
auth_repo_url = (
|
||||
f"//{self.legacy_config.github_username}:"
|
||||
f"{self.legacy_config.github_api_key}@".join(split_url)
|
||||
)
|
||||
try:
|
||||
Repo.clone_from(url=auth_repo_url, to_path=clone_path)
|
||||
except Exception as e:
|
||||
raise CommandExecutionError(f"Could not clone repo: {e}")
|
||||
|
||||
return f"""Cloned {url} to {clone_path}"""
|
||||
|
||||
@@ -7,206 +7,216 @@ import time
|
||||
import uuid
|
||||
from base64 import b64decode
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
from PIL import Image
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.protocols import CommandProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
COMMAND_CATEGORY = "text_to_image"
|
||||
COMMAND_CATEGORY_TITLE = "Text to Image"
|
||||
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.models.command import Command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@command(
|
||||
"generate_image",
|
||||
"Generates an Image",
|
||||
{
|
||||
"prompt": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The prompt used to generate the image",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
lambda config: bool(config.image_provider),
|
||||
"Requires a image provider to be set.",
|
||||
)
|
||||
def generate_image(prompt: str, agent: Agent, size: int = 256) -> str:
|
||||
"""Generate an image from a prompt.
|
||||
class ImageGeneratorComponent(CommandProvider):
|
||||
"""A component that provides commands to generate images from text prompts."""
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
size (int, optional): The size of the image. Defaults to 256.
|
||||
Not supported by HuggingFace.
|
||||
def __init__(self, workspace: FileStorage, config: Config):
|
||||
self._enabled = bool(config.image_provider)
|
||||
self._disabled_reason = "No image provider set."
|
||||
self.workspace = workspace
|
||||
self.legacy_config = config
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
filename = agent.workspace.root / f"{str(uuid.uuid4())}.jpg"
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.generate_image
|
||||
|
||||
# DALL-E
|
||||
if agent.legacy_config.image_provider == "dalle":
|
||||
return generate_image_with_dalle(prompt, filename, size, agent)
|
||||
# HuggingFace
|
||||
elif agent.legacy_config.image_provider == "huggingface":
|
||||
return generate_image_with_hf(prompt, filename, agent)
|
||||
# SD WebUI
|
||||
elif agent.legacy_config.image_provider == "sdwebui":
|
||||
return generate_image_with_sd_webui(prompt, filename, agent, size)
|
||||
return "No Image Provider Set"
|
||||
@command(
|
||||
parameters={
|
||||
"prompt": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The prompt used to generate the image",
|
||||
required=True,
|
||||
),
|
||||
"size": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="The size of the image",
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
def generate_image(self, prompt: str, size: int) -> str:
|
||||
"""Generate an image from a prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
size (int, optional): The size of the image. Defaults to 256.
|
||||
Not supported by HuggingFace.
|
||||
|
||||
def generate_image_with_hf(prompt: str, output_file: Path, agent: Agent) -> str:
|
||||
"""Generate an image with HuggingFace's API.
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
filename = self.workspace.root / f"{str(uuid.uuid4())}.jpg"
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (Path): The filename to save the image to
|
||||
# DALL-E
|
||||
if self.legacy_config.image_provider == "dalle":
|
||||
return self.generate_image_with_dalle(prompt, filename, size)
|
||||
# HuggingFace
|
||||
elif self.legacy_config.image_provider == "huggingface":
|
||||
return self.generate_image_with_hf(prompt, filename)
|
||||
# SD WebUI
|
||||
elif self.legacy_config.image_provider == "sdwebui":
|
||||
return self.generate_image_with_sd_webui(prompt, filename, size)
|
||||
return "No Image Provider Set"
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
API_URL = f"https://api-inference.huggingface.co/models/{agent.legacy_config.huggingface_image_model}" # noqa: E501
|
||||
if agent.legacy_config.huggingface_api_token is None:
|
||||
raise ValueError(
|
||||
"You need to set your Hugging Face API token in the config file."
|
||||
def generate_image_with_hf(self, prompt: str, output_file: Path) -> str:
|
||||
"""Generate an image with HuggingFace's API.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (Path): The filename to save the image to
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
API_URL = f"https://api-inference.huggingface.co/models/{self.legacy_config.huggingface_image_model}" # noqa: E501
|
||||
if self.legacy_config.huggingface_api_token is None:
|
||||
raise ValueError(
|
||||
"You need to set your Hugging Face API token in the config file."
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.legacy_config.huggingface_api_token}",
|
||||
"X-Use-Cache": "false",
|
||||
}
|
||||
|
||||
retry_count = 0
|
||||
while retry_count < 10:
|
||||
response = requests.post(
|
||||
API_URL,
|
||||
headers=headers,
|
||||
json={
|
||||
"inputs": prompt,
|
||||
},
|
||||
)
|
||||
|
||||
if response.ok:
|
||||
try:
|
||||
image = Image.open(io.BytesIO(response.content))
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
image.save(output_file)
|
||||
return f"Saved to disk: {output_file}"
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
else:
|
||||
try:
|
||||
error = json.loads(response.text)
|
||||
if "estimated_time" in error:
|
||||
delay = error["estimated_time"]
|
||||
logger.debug(response.text)
|
||||
logger.info("Retrying in", delay)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
|
||||
retry_count += 1
|
||||
|
||||
return "Error creating image."
|
||||
|
||||
def generate_image_with_dalle(
|
||||
self, prompt: str, output_file: Path, size: int
|
||||
) -> str:
|
||||
"""Generate an image with DALL-E.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (Path): The filename to save the image to
|
||||
size (int): The size of the image
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
|
||||
# Check for supported image sizes
|
||||
if size not in [256, 512, 1024]:
|
||||
closest = min([256, 512, 1024], key=lambda x: abs(x - size))
|
||||
logger.info(
|
||||
"DALL-E only supports image sizes of 256x256, 512x512, or 1024x1024. "
|
||||
f"Setting to {closest}, was {size}."
|
||||
)
|
||||
size = closest
|
||||
|
||||
response = OpenAI(
|
||||
api_key=self.legacy_config.openai_credentials.api_key.get_secret_value()
|
||||
).images.generate(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
size=f"{size}x{size}",
|
||||
response_format="b64_json",
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {agent.legacy_config.huggingface_api_token}",
|
||||
"X-Use-Cache": "false",
|
||||
}
|
||||
|
||||
retry_count = 0
|
||||
while retry_count < 10:
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
|
||||
image_data = b64decode(response.data[0].b64_json)
|
||||
|
||||
with open(output_file, mode="wb") as png:
|
||||
png.write(image_data)
|
||||
|
||||
return f"Saved to disk: {output_file}"
|
||||
|
||||
def generate_image_with_sd_webui(
|
||||
self,
|
||||
prompt: str,
|
||||
output_file: Path,
|
||||
size: int = 512,
|
||||
negative_prompt: str = "",
|
||||
extra: dict = {},
|
||||
) -> str:
|
||||
"""Generate an image with Stable Diffusion webui.
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (str): The filename to save the image to
|
||||
size (int, optional): The size of the image. Defaults to 256.
|
||||
negative_prompt (str, optional): The negative prompt to use. Defaults to "".
|
||||
extra (dict, optional): Extra parameters to pass to the API. Defaults to {}.
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
# Create a session and set the basic auth if needed
|
||||
s = requests.Session()
|
||||
if self.legacy_config.sd_webui_auth:
|
||||
username, password = self.legacy_config.sd_webui_auth.split(":")
|
||||
s.auth = (username, password or "")
|
||||
|
||||
# Generate the images
|
||||
response = requests.post(
|
||||
API_URL,
|
||||
headers=headers,
|
||||
f"{self.legacy_config.sd_webui_url}/sdapi/v1/txt2img",
|
||||
json={
|
||||
"inputs": prompt,
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"sampler_index": "DDIM",
|
||||
"steps": 20,
|
||||
"config_scale": 7.0,
|
||||
"width": size,
|
||||
"height": size,
|
||||
"n_iter": 1,
|
||||
**extra,
|
||||
},
|
||||
)
|
||||
|
||||
if response.ok:
|
||||
try:
|
||||
image = Image.open(io.BytesIO(response.content))
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
image.save(output_file)
|
||||
return f"Saved to disk: {output_file}"
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
else:
|
||||
try:
|
||||
error = json.loads(response.text)
|
||||
if "estimated_time" in error:
|
||||
delay = error["estimated_time"]
|
||||
logger.debug(response.text)
|
||||
logger.info("Retrying in", delay)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
logger.info(f"Image Generated for prompt: '{prompt}'")
|
||||
|
||||
retry_count += 1
|
||||
# Save the image to disk
|
||||
response = response.json()
|
||||
b64 = b64decode(response["images"][0].split(",", 1)[0])
|
||||
image = Image.open(io.BytesIO(b64))
|
||||
image.save(output_file)
|
||||
|
||||
return "Error creating image."
|
||||
|
||||
|
||||
def generate_image_with_dalle(
|
||||
prompt: str, output_file: Path, size: int, agent: Agent
|
||||
) -> str:
|
||||
"""Generate an image with DALL-E.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (Path): The filename to save the image to
|
||||
size (int): The size of the image
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
|
||||
# Check for supported image sizes
|
||||
if size not in [256, 512, 1024]:
|
||||
closest = min([256, 512, 1024], key=lambda x: abs(x - size))
|
||||
logger.info(
|
||||
"DALL-E only supports image sizes of 256x256, 512x512, or 1024x1024. "
|
||||
f"Setting to {closest}, was {size}."
|
||||
)
|
||||
size = closest
|
||||
|
||||
response = OpenAI(
|
||||
api_key=agent.legacy_config.openai_credentials.api_key.get_secret_value()
|
||||
).images.generate(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
size=f"{size}x{size}",
|
||||
response_format="b64_json",
|
||||
)
|
||||
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
|
||||
image_data = b64decode(response.data[0].b64_json)
|
||||
|
||||
with open(output_file, mode="wb") as png:
|
||||
png.write(image_data)
|
||||
|
||||
return f"Saved to disk: {output_file}"
|
||||
|
||||
|
||||
def generate_image_with_sd_webui(
|
||||
prompt: str,
|
||||
output_file: Path,
|
||||
agent: Agent,
|
||||
size: int = 512,
|
||||
negative_prompt: str = "",
|
||||
extra: dict = {},
|
||||
) -> str:
|
||||
"""Generate an image with Stable Diffusion webui.
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (str): The filename to save the image to
|
||||
size (int, optional): The size of the image. Defaults to 256.
|
||||
negative_prompt (str, optional): The negative prompt to use. Defaults to "".
|
||||
extra (dict, optional): Extra parameters to pass to the API. Defaults to {}.
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
# Create a session and set the basic auth if needed
|
||||
s = requests.Session()
|
||||
if agent.legacy_config.sd_webui_auth:
|
||||
username, password = agent.legacy_config.sd_webui_auth.split(":")
|
||||
s.auth = (username, password or "")
|
||||
|
||||
# Generate the images
|
||||
response = requests.post(
|
||||
f"{agent.legacy_config.sd_webui_url}/sdapi/v1/txt2img",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"sampler_index": "DDIM",
|
||||
"steps": 20,
|
||||
"config_scale": 7.0,
|
||||
"width": size,
|
||||
"height": size,
|
||||
"n_iter": 1,
|
||||
**extra,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Image Generated for prompt: '{prompt}'")
|
||||
|
||||
# Save the image to disk
|
||||
response = response.json()
|
||||
b64 = b64decode(response["images"][0].split(",", 1)[0])
|
||||
image = Image.open(io.BytesIO(b64))
|
||||
image.save(output_file)
|
||||
|
||||
return f"Saved to disk: {output_file}"
|
||||
return f"Saved to disk: {output_file}"
|
||||
|
||||
@@ -1,69 +1,55 @@
|
||||
"""Commands to control the internal state of the program"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
import time
|
||||
from typing import Iterator
|
||||
|
||||
from autogpt.agents.features.context import get_agent_context
|
||||
from autogpt.agents.utils.exceptions import AgentFinished, InvalidArgumentError
|
||||
from autogpt.agents.protocols import CommandProvider, DirectiveProvider, MessageProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config.ai_profile import AIProfile
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.core.resource.model_providers.schema import ChatMessage
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
COMMAND_CATEGORY = "system"
|
||||
COMMAND_CATEGORY_TITLE = "System"
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.agent import Agent
|
||||
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.utils.exceptions import AgentFinished
|
||||
from autogpt.utils.utils import DEFAULT_FINISH_COMMAND
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@command(
|
||||
"finish",
|
||||
"Use this to shut down once you have completed your task,"
|
||||
" or when there are insurmountable problems that make it impossible"
|
||||
" for you to finish your task.",
|
||||
{
|
||||
"reason": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="A summary to the user of how the goals were accomplished",
|
||||
required=True,
|
||||
class SystemComponent(DirectiveProvider, MessageProvider, CommandProvider):
|
||||
"""Component for system messages and commands."""
|
||||
|
||||
def __init__(self, config: Config, profile: AIProfile):
|
||||
self.legacy_config = config
|
||||
self.profile = profile
|
||||
|
||||
def get_constraints(self) -> Iterator[str]:
|
||||
if self.profile.api_budget > 0.0:
|
||||
yield (
|
||||
f"It takes money to let you run. "
|
||||
f"Your API budget is ${self.profile.api_budget:.3f}"
|
||||
)
|
||||
|
||||
def get_messages(self) -> Iterator[ChatMessage]:
|
||||
# Clock
|
||||
yield ChatMessage.system(
|
||||
f"## Clock\nThe current time and date is {time.strftime('%c')}"
|
||||
)
|
||||
},
|
||||
)
|
||||
def finish(reason: str, agent: Agent) -> None:
|
||||
"""
|
||||
A function that takes in a string and exits the program
|
||||
|
||||
Parameters:
|
||||
reason (str): A summary to the user of how the goals were accomplished.
|
||||
Returns:
|
||||
A result string from create chat completion. A list of suggestions to
|
||||
improve the code.
|
||||
"""
|
||||
raise AgentFinished(reason)
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.finish
|
||||
|
||||
|
||||
@command(
|
||||
"hide_context_item",
|
||||
"Hide an open file, folder or other context item, to save memory.",
|
||||
{
|
||||
"number": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="The 1-based index of the context item to hide",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
available=lambda a: bool(get_agent_context(a)),
|
||||
)
|
||||
def close_context_item(number: int, agent: Agent) -> str:
|
||||
assert (context := get_agent_context(agent)) is not None
|
||||
|
||||
if number > len(context.items) or number == 0:
|
||||
raise InvalidArgumentError(f"Index {number} out of range")
|
||||
|
||||
context.close(number)
|
||||
return f"Context item {number} hidden ✅"
|
||||
@command(
|
||||
names=[DEFAULT_FINISH_COMMAND],
|
||||
parameters={
|
||||
"reason": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="A summary to the user of how the goals were accomplished",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
def finish(self, reason: str):
|
||||
"""Use this to shut down once you have completed your task,
|
||||
or when there are insurmountable problems that make it impossible
|
||||
for you to finish your task."""
|
||||
raise AgentFinished(reason)
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def get_datetime() -> str:
|
||||
"""Return the current date and time
|
||||
|
||||
Returns:
|
||||
str: The current date and time
|
||||
"""
|
||||
return "Current date and time: " + datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -1,32 +1,37 @@
|
||||
"""Commands to interact with the user"""
|
||||
from typing import Iterator
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.protocols import CommandProvider
|
||||
from autogpt.app.utils import clean_input
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
COMMAND_CATEGORY = "user_interaction"
|
||||
COMMAND_CATEGORY_TITLE = "User Interaction"
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.utils.utils import DEFAULT_ASK_COMMAND
|
||||
|
||||
|
||||
@command(
|
||||
"ask_user",
|
||||
(
|
||||
"If you need more details or information regarding the given goals,"
|
||||
" you can ask the user for input"
|
||||
),
|
||||
{
|
||||
"question": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The question or prompt to the user",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
enabled=lambda config: not config.noninteractive_mode,
|
||||
)
|
||||
async def ask_user(question: str, agent: Agent) -> str:
|
||||
print(f"\nQ: {question}")
|
||||
resp = clean_input(agent.legacy_config, "A:")
|
||||
return f"The user's answer: '{resp}'"
|
||||
class UserInteractionComponent(CommandProvider):
|
||||
"""Provides commands to interact with the user."""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self.config = config
|
||||
self._enabled = not config.noninteractive_mode
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.ask_user
|
||||
|
||||
@command(
|
||||
names=[DEFAULT_ASK_COMMAND],
|
||||
parameters={
|
||||
"question": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The question or prompt to the user",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def ask_user(self, question: str) -> str:
|
||||
"""If you need more details or information regarding the given goals,
|
||||
you can ask the user for input."""
|
||||
print(f"\nQ: {question}")
|
||||
resp = clean_input(self.config, "A:")
|
||||
return f"The user's answer: '{resp}'"
|
||||
|
||||
@@ -1,169 +1,195 @@
|
||||
"""Commands to search the web with"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Iterator
|
||||
|
||||
from duckduckgo_search import DDGS
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.utils.exceptions import ConfigurationError
|
||||
from autogpt.agents.protocols import CommandProvider, DirectiveProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
COMMAND_CATEGORY = "web_search"
|
||||
COMMAND_CATEGORY_TITLE = "Web Search"
|
||||
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.utils.exceptions import ConfigurationError
|
||||
|
||||
DUCKDUCKGO_MAX_ATTEMPTS = 3
|
||||
|
||||
|
||||
@command(
|
||||
"web_search",
|
||||
"Searches the web",
|
||||
{
|
||||
"query": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The search query",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
aliases=["search"],
|
||||
)
|
||||
def web_search(query: str, agent: Agent, num_results: int = 8) -> str:
|
||||
"""Return the results of a Google search
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
num_results (int): The number of results to return.
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
search_results = []
|
||||
attempts = 0
|
||||
|
||||
while attempts < DUCKDUCKGO_MAX_ATTEMPTS:
|
||||
if not query:
|
||||
return json.dumps(search_results)
|
||||
|
||||
search_results = DDGS().text(query, max_results=num_results)
|
||||
|
||||
if search_results:
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
attempts += 1
|
||||
|
||||
search_results = [
|
||||
{
|
||||
"title": r["title"],
|
||||
"url": r["href"],
|
||||
**({"exerpt": r["body"]} if r.get("body") else {}),
|
||||
}
|
||||
for r in search_results
|
||||
]
|
||||
|
||||
results = (
|
||||
"## Search results\n"
|
||||
# "Read these results carefully."
|
||||
# " Extract the information you need for your task from the list of results"
|
||||
# " if possible. Otherwise, choose a webpage from the list to read entirely."
|
||||
# "\n\n"
|
||||
) + "\n\n".join(
|
||||
f"### \"{r['title']}\"\n"
|
||||
f"**URL:** {r['url']} \n"
|
||||
"**Excerpt:** " + (f'"{exerpt}"' if (exerpt := r.get("exerpt")) else "N/A")
|
||||
for r in search_results
|
||||
)
|
||||
return safe_google_results(results)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@command(
|
||||
"google",
|
||||
"Google Search",
|
||||
{
|
||||
"query": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The search query",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
lambda config: bool(config.google_api_key)
|
||||
and bool(config.google_custom_search_engine_id),
|
||||
"Configure google_api_key and custom_search_engine_id.",
|
||||
aliases=["search"],
|
||||
)
|
||||
def google(query: str, agent: Agent, num_results: int = 8) -> str | list[str]:
|
||||
"""Return the results of a Google search using the official Google API
|
||||
class WebSearchComponent(DirectiveProvider, CommandProvider):
|
||||
"""Provides commands to search the web."""
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
num_results (int): The number of results to return.
|
||||
def __init__(self, config: Config):
|
||||
self.legacy_config = config
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
try:
|
||||
# Get the Google API key and Custom Search Engine ID from the config file
|
||||
api_key = agent.legacy_config.google_api_key
|
||||
custom_search_engine_id = agent.legacy_config.google_custom_search_engine_id
|
||||
|
||||
# Initialize the Custom Search API service
|
||||
service = build("customsearch", "v1", developerKey=api_key)
|
||||
|
||||
# Send the search query and retrieve the results
|
||||
result = (
|
||||
service.cse()
|
||||
.list(q=query, cx=custom_search_engine_id, num=num_results)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Extract the search result items from the response
|
||||
search_results = result.get("items", [])
|
||||
|
||||
# Create a list of only the URLs from the search results
|
||||
search_results_links = [item["link"] for item in search_results]
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
error_details = json.loads(e.content.decode())
|
||||
|
||||
# Check if the error is related to an invalid or missing API key
|
||||
if error_details.get("error", {}).get(
|
||||
"code"
|
||||
) == 403 and "invalid API key" in error_details.get("error", {}).get(
|
||||
"message", ""
|
||||
if (
|
||||
not self.legacy_config.google_api_key
|
||||
or not self.legacy_config.google_custom_search_engine_id
|
||||
):
|
||||
raise ConfigurationError(
|
||||
"The provided Google API key is invalid or missing."
|
||||
logger.info(
|
||||
"Configure google_api_key and custom_search_engine_id "
|
||||
"to use Google API search."
|
||||
)
|
||||
raise
|
||||
# google_result can be a list or a string depending on the search results
|
||||
|
||||
# Return the list of search result URLs
|
||||
return safe_google_results(search_results_links)
|
||||
def get_resources(self) -> Iterator[str]:
|
||||
yield "Internet access for searches and information gathering."
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.web_search
|
||||
|
||||
def safe_google_results(results: str | list) -> str:
|
||||
"""
|
||||
Return the results of a Google search in a safe format.
|
||||
if (
|
||||
self.legacy_config.google_api_key
|
||||
and self.legacy_config.google_custom_search_engine_id
|
||||
):
|
||||
yield self.google
|
||||
|
||||
Args:
|
||||
results (str | list): The search results.
|
||||
@command(
|
||||
["web_search", "search"],
|
||||
"Searches the web",
|
||||
{
|
||||
"query": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The search query",
|
||||
required=True,
|
||||
),
|
||||
"num_results": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="The number of results to return",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
def web_search(self, query: str, num_results: int = 8) -> str:
|
||||
"""Return the results of a Google search
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
if isinstance(results, list):
|
||||
safe_message = json.dumps(
|
||||
[result.encode("utf-8", "ignore").decode("utf-8") for result in results]
|
||||
Args:
|
||||
query (str): The search query.
|
||||
num_results (int): The number of results to return.
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
search_results = []
|
||||
attempts = 0
|
||||
|
||||
while attempts < DUCKDUCKGO_MAX_ATTEMPTS:
|
||||
if not query:
|
||||
return json.dumps(search_results)
|
||||
|
||||
search_results = DDGS().text(query, max_results=num_results)
|
||||
|
||||
if search_results:
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
attempts += 1
|
||||
|
||||
search_results = [
|
||||
{
|
||||
"title": r["title"],
|
||||
"url": r["href"],
|
||||
**({"exerpt": r["body"]} if r.get("body") else {}),
|
||||
}
|
||||
for r in search_results
|
||||
]
|
||||
|
||||
results = ("## Search results\n") + "\n\n".join(
|
||||
f"### \"{r['title']}\"\n"
|
||||
f"**URL:** {r['url']} \n"
|
||||
"**Excerpt:** " + (f'"{exerpt}"' if (exerpt := r.get("exerpt")) else "N/A")
|
||||
for r in search_results
|
||||
)
|
||||
else:
|
||||
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
|
||||
return safe_message
|
||||
return self.safe_google_results(results)
|
||||
|
||||
@command(
|
||||
["google"],
|
||||
"Google Search",
|
||||
{
|
||||
"query": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The search query",
|
||||
required=True,
|
||||
),
|
||||
"num_results": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="The number of results to return",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
def google(self, query: str, num_results: int = 8) -> str | list[str]:
|
||||
"""Return the results of a Google search using the official Google API
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
num_results (int): The number of results to return.
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
try:
|
||||
# Get the Google API key and Custom Search Engine ID from the config file
|
||||
api_key = self.legacy_config.google_api_key
|
||||
custom_search_engine_id = self.legacy_config.google_custom_search_engine_id
|
||||
|
||||
# Initialize the Custom Search API service
|
||||
service = build("customsearch", "v1", developerKey=api_key)
|
||||
|
||||
# Send the search query and retrieve the results
|
||||
result = (
|
||||
service.cse()
|
||||
.list(q=query, cx=custom_search_engine_id, num=num_results)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Extract the search result items from the response
|
||||
search_results = result.get("items", [])
|
||||
|
||||
# Create a list of only the URLs from the search results
|
||||
search_results_links = [item["link"] for item in search_results]
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
error_details = json.loads(e.content.decode())
|
||||
|
||||
# Check if the error is related to an invalid or missing API key
|
||||
if error_details.get("error", {}).get(
|
||||
"code"
|
||||
) == 403 and "invalid API key" in error_details.get("error", {}).get(
|
||||
"message", ""
|
||||
):
|
||||
raise ConfigurationError(
|
||||
"The provided Google API key is invalid or missing."
|
||||
)
|
||||
raise
|
||||
# google_result can be a list or a string depending on the search results
|
||||
|
||||
# Return the list of search result URLs
|
||||
return self.safe_google_results(search_results_links)
|
||||
|
||||
def safe_google_results(self, results: str | list) -> str:
|
||||
"""
|
||||
Return the results of a Google search in a safe format.
|
||||
|
||||
Args:
|
||||
results (str | list): The search results.
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
if isinstance(results, list):
|
||||
safe_message = json.dumps(
|
||||
[result.encode("utf-8", "ignore").decode("utf-8") for result in results]
|
||||
)
|
||||
else:
|
||||
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
|
||||
return safe_message
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
"""Commands for browsing a website"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from sys import platform
|
||||
from typing import TYPE_CHECKING, Optional, Type
|
||||
from typing import Iterator, Type
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
@@ -32,21 +28,19 @@ from webdriver_manager.chrome import ChromeDriverManager
|
||||
from webdriver_manager.firefox import GeckoDriverManager
|
||||
from webdriver_manager.microsoft import EdgeChromiumDriverManager as EdgeDriverManager
|
||||
|
||||
from autogpt.agents.utils.exceptions import CommandExecutionError, TooMuchOutputError
|
||||
from autogpt.agents.protocols import CommandProvider, DirectiveProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config import Config
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.processing.html import extract_hyperlinks, format_hyperlinks
|
||||
from autogpt.processing.text import extract_information, summarize_text
|
||||
from autogpt.url_utils.validators import validate_url
|
||||
|
||||
COMMAND_CATEGORY = "web_browse"
|
||||
COMMAND_CATEGORY_TITLE = "Web Browsing"
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.config import Config
|
||||
|
||||
from autogpt.utils.exceptions import CommandExecutionError, TooMuchOutputError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,321 +53,324 @@ class BrowsingError(CommandExecutionError):
|
||||
"""An error occurred while trying to browse the page"""
|
||||
|
||||
|
||||
@command(
|
||||
"read_webpage",
|
||||
(
|
||||
"Read a webpage, and extract specific information from it."
|
||||
" You must specify either topics_of_interest, a question, or get_raw_content."
|
||||
),
|
||||
{
|
||||
"url": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The URL to visit",
|
||||
required=True,
|
||||
class WebSeleniumComponent(DirectiveProvider, CommandProvider):
|
||||
"""Provides commands to browse the web using Selenium."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
llm_provider: ChatModelProvider,
|
||||
model_info: ChatModelInfo,
|
||||
):
|
||||
self.legacy_config = config
|
||||
self.llm_provider = llm_provider
|
||||
self.model_info = model_info
|
||||
|
||||
def get_resources(self) -> Iterator[str]:
|
||||
yield "Ability to read websites."
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.read_webpage
|
||||
|
||||
@command(
|
||||
["read_webpage"],
|
||||
(
|
||||
"Read a webpage, and extract specific information from it."
|
||||
" You must specify either topics_of_interest,"
|
||||
" a question, or get_raw_content."
|
||||
),
|
||||
"topics_of_interest": JSONSchema(
|
||||
type=JSONSchema.Type.ARRAY,
|
||||
items=JSONSchema(type=JSONSchema.Type.STRING),
|
||||
description=(
|
||||
"A list of topics about which you want to extract information "
|
||||
"from the page."
|
||||
{
|
||||
"url": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The URL to visit",
|
||||
required=True,
|
||||
),
|
||||
required=False,
|
||||
),
|
||||
"question": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description=(
|
||||
"A question that you want to answer using the content of the webpage."
|
||||
"topics_of_interest": JSONSchema(
|
||||
type=JSONSchema.Type.ARRAY,
|
||||
items=JSONSchema(type=JSONSchema.Type.STRING),
|
||||
description=(
|
||||
"A list of topics about which you want to extract information "
|
||||
"from the page."
|
||||
),
|
||||
required=False,
|
||||
),
|
||||
required=False,
|
||||
),
|
||||
"get_raw_content": JSONSchema(
|
||||
type=JSONSchema.Type.BOOLEAN,
|
||||
description=(
|
||||
"If true, the unprocessed content of the webpage will be returned. "
|
||||
"This consumes a lot of tokens, so use it with caution."
|
||||
"question": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description=(
|
||||
"A question you want to answer using the content of the webpage."
|
||||
),
|
||||
required=False,
|
||||
),
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
@validate_url
|
||||
async def read_webpage(
|
||||
url: str,
|
||||
agent: Agent,
|
||||
*,
|
||||
topics_of_interest: list[str] = [],
|
||||
get_raw_content: bool = False,
|
||||
question: str = "",
|
||||
) -> str:
|
||||
"""Browse a website and return the answer and links to the user
|
||||
"get_raw_content": JSONSchema(
|
||||
type=JSONSchema.Type.BOOLEAN,
|
||||
description=(
|
||||
"If true, the unprocessed content of the webpage will be returned. "
|
||||
"This consumes a lot of tokens, so use it with caution."
|
||||
),
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
@validate_url
|
||||
async def read_webpage(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
topics_of_interest: list[str] = [],
|
||||
get_raw_content: bool = False,
|
||||
question: str = "",
|
||||
) -> str:
|
||||
"""Browse a website and return the answer and links to the user
|
||||
|
||||
Args:
|
||||
url (str): The url of the website to browse
|
||||
question (str): The question to answer using the content of the webpage
|
||||
Args:
|
||||
url (str): The url of the website to browse
|
||||
question (str): The question to answer using the content of the webpage
|
||||
|
||||
Returns:
|
||||
str: The answer and links to the user and the webdriver
|
||||
"""
|
||||
driver = None
|
||||
try:
|
||||
driver = await open_page_in_browser(url, agent.legacy_config)
|
||||
Returns:
|
||||
str: The answer and links to the user and the webdriver
|
||||
"""
|
||||
driver = None
|
||||
try:
|
||||
driver = await self.open_page_in_browser(url, self.legacy_config)
|
||||
|
||||
text = scrape_text_with_selenium(driver)
|
||||
links = scrape_links_with_selenium(driver, url)
|
||||
text = self.scrape_text_with_selenium(driver)
|
||||
links = self.scrape_links_with_selenium(driver, url)
|
||||
|
||||
return_literal_content = True
|
||||
summarized = False
|
||||
if not text:
|
||||
return f"Website did not contain any text.\n\nLinks: {links}"
|
||||
elif get_raw_content:
|
||||
if (
|
||||
output_tokens := agent.llm_provider.count_tokens(text, agent.llm.name)
|
||||
) > MAX_RAW_CONTENT_LENGTH:
|
||||
oversize_factor = round(output_tokens / MAX_RAW_CONTENT_LENGTH, 1)
|
||||
raise TooMuchOutputError(
|
||||
f"Page content is {oversize_factor}x the allowed length "
|
||||
"for `get_raw_content=true`"
|
||||
)
|
||||
return text + (f"\n\nLinks: {links}" if links else "")
|
||||
else:
|
||||
text = await summarize_memorize_webpage(
|
||||
url, text, question or None, topics_of_interest, agent, driver
|
||||
)
|
||||
return_literal_content = bool(question)
|
||||
summarized = True
|
||||
|
||||
# Limit links to LINKS_TO_RETURN
|
||||
if len(links) > LINKS_TO_RETURN:
|
||||
links = links[:LINKS_TO_RETURN]
|
||||
|
||||
text_fmt = f"'''{text}'''" if "\n" in text else f"'{text}'"
|
||||
links_fmt = "\n".join(f"- {link}" for link in links)
|
||||
return (
|
||||
f"Page content{' (summary)' if summarized else ''}:"
|
||||
if return_literal_content
|
||||
else "Answer gathered from webpage:"
|
||||
) + f" {text_fmt}\n\nLinks:\n{links_fmt}"
|
||||
|
||||
except WebDriverException as e:
|
||||
# These errors are often quite long and include lots of context.
|
||||
# Just grab the first line.
|
||||
msg = e.msg.split("\n")[0]
|
||||
if "net::" in msg:
|
||||
raise BrowsingError(
|
||||
"A networking error occurred while trying to load the page: %s"
|
||||
% re.sub(r"^unknown error: ", "", msg)
|
||||
)
|
||||
raise CommandExecutionError(msg)
|
||||
finally:
|
||||
if driver:
|
||||
close_browser(driver)
|
||||
|
||||
|
||||
def scrape_text_with_selenium(driver: WebDriver) -> str:
|
||||
"""Scrape text from a browser window using selenium
|
||||
|
||||
Args:
|
||||
driver (WebDriver): A driver object representing the browser window to scrape
|
||||
|
||||
Returns:
|
||||
str: the text scraped from the website
|
||||
"""
|
||||
|
||||
# Get the HTML content directly from the browser's DOM
|
||||
page_source = driver.execute_script("return document.body.outerHTML;")
|
||||
soup = BeautifulSoup(page_source, "html.parser")
|
||||
|
||||
for script in soup(["script", "style"]):
|
||||
script.extract()
|
||||
|
||||
text = soup.get_text()
|
||||
lines = (line.strip() for line in text.splitlines())
|
||||
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
|
||||
text = "\n".join(chunk for chunk in chunks if chunk)
|
||||
return text
|
||||
|
||||
|
||||
def scrape_links_with_selenium(driver: WebDriver, base_url: str) -> list[str]:
|
||||
"""Scrape links from a website using selenium
|
||||
|
||||
Args:
|
||||
driver (WebDriver): A driver object representing the browser window to scrape
|
||||
base_url (str): The base URL to use for resolving relative links
|
||||
|
||||
Returns:
|
||||
List[str]: The links scraped from the website
|
||||
"""
|
||||
page_source = driver.page_source
|
||||
soup = BeautifulSoup(page_source, "html.parser")
|
||||
|
||||
for script in soup(["script", "style"]):
|
||||
script.extract()
|
||||
|
||||
hyperlinks = extract_hyperlinks(soup, base_url)
|
||||
|
||||
return format_hyperlinks(hyperlinks)
|
||||
|
||||
|
||||
async def open_page_in_browser(url: str, config: Config) -> WebDriver:
|
||||
"""Open a browser window and load a web page using Selenium
|
||||
|
||||
Params:
|
||||
url (str): The URL of the page to load
|
||||
config (Config): The applicable application configuration
|
||||
|
||||
Returns:
|
||||
driver (WebDriver): A driver object representing the browser window to scrape
|
||||
"""
|
||||
logging.getLogger("selenium").setLevel(logging.CRITICAL)
|
||||
|
||||
options_available: dict[str, Type[BrowserOptions]] = {
|
||||
"chrome": ChromeOptions,
|
||||
"edge": EdgeOptions,
|
||||
"firefox": FirefoxOptions,
|
||||
"safari": SafariOptions,
|
||||
}
|
||||
|
||||
options: BrowserOptions = options_available[config.selenium_web_browser]()
|
||||
options.add_argument(f"user-agent={config.user_agent}")
|
||||
|
||||
if isinstance(options, FirefoxOptions):
|
||||
if config.selenium_headless:
|
||||
options.headless = True
|
||||
options.add_argument("--disable-gpu")
|
||||
driver = FirefoxDriver(
|
||||
service=GeckoDriverService(GeckoDriverManager().install()), options=options
|
||||
)
|
||||
elif isinstance(options, EdgeOptions):
|
||||
driver = EdgeDriver(
|
||||
service=EdgeDriverService(EdgeDriverManager().install()), options=options
|
||||
)
|
||||
elif isinstance(options, SafariOptions):
|
||||
# Requires a bit more setup on the users end.
|
||||
# See https://developer.apple.com/documentation/webkit/testing_with_webdriver_in_safari # noqa: E501
|
||||
driver = SafariDriver(options=options)
|
||||
elif isinstance(options, ChromeOptions):
|
||||
if platform == "linux" or platform == "linux2":
|
||||
options.add_argument("--disable-dev-shm-usage")
|
||||
options.add_argument("--remote-debugging-port=9222")
|
||||
|
||||
options.add_argument("--no-sandbox")
|
||||
if config.selenium_headless:
|
||||
options.add_argument("--headless=new")
|
||||
options.add_argument("--disable-gpu")
|
||||
|
||||
_sideload_chrome_extensions(options, config.app_data_dir / "assets" / "crx")
|
||||
|
||||
if (chromium_driver_path := Path("/usr/bin/chromedriver")).exists():
|
||||
chrome_service = ChromeDriverService(str(chromium_driver_path))
|
||||
else:
|
||||
try:
|
||||
chrome_driver = ChromeDriverManager().install()
|
||||
except AttributeError as e:
|
||||
if "'NoneType' object has no attribute 'split'" in str(e):
|
||||
# https://github.com/SergeyPirogov/webdriver_manager/issues/649
|
||||
logger.critical(
|
||||
"Connecting to browser failed: is Chrome or Chromium installed?"
|
||||
return_literal_content = True
|
||||
summarized = False
|
||||
if not text:
|
||||
return f"Website did not contain any text.\n\nLinks: {links}"
|
||||
elif get_raw_content:
|
||||
if (
|
||||
output_tokens := self.llm_provider.count_tokens(
|
||||
text, self.model_info.name
|
||||
)
|
||||
raise
|
||||
chrome_service = ChromeDriverService(chrome_driver)
|
||||
driver = ChromeDriver(service=chrome_service, options=options)
|
||||
) > MAX_RAW_CONTENT_LENGTH:
|
||||
oversize_factor = round(output_tokens / MAX_RAW_CONTENT_LENGTH, 1)
|
||||
raise TooMuchOutputError(
|
||||
f"Page content is {oversize_factor}x the allowed length "
|
||||
"for `get_raw_content=true`"
|
||||
)
|
||||
return text + (f"\n\nLinks: {links}" if links else "")
|
||||
else:
|
||||
text = await self.summarize_webpage(
|
||||
text, question or None, topics_of_interest
|
||||
)
|
||||
return_literal_content = bool(question)
|
||||
summarized = True
|
||||
|
||||
driver.get(url)
|
||||
# Limit links to LINKS_TO_RETURN
|
||||
if len(links) > LINKS_TO_RETURN:
|
||||
links = links[:LINKS_TO_RETURN]
|
||||
|
||||
# Wait for page to be ready, sleep 2 seconds, wait again until page ready.
|
||||
# This allows the cookiewall squasher time to get rid of cookie walls.
|
||||
WebDriverWait(driver, 10).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
await asyncio.sleep(2)
|
||||
WebDriverWait(driver, 10).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
text_fmt = f"'''{text}'''" if "\n" in text else f"'{text}'"
|
||||
links_fmt = "\n".join(f"- {link}" for link in links)
|
||||
return (
|
||||
f"Page content{' (summary)' if summarized else ''}:"
|
||||
if return_literal_content
|
||||
else "Answer gathered from webpage:"
|
||||
) + f" {text_fmt}\n\nLinks:\n{links_fmt}"
|
||||
|
||||
return driver
|
||||
except WebDriverException as e:
|
||||
# These errors are often quite long and include lots of context.
|
||||
# Just grab the first line.
|
||||
msg = e.msg.split("\n")[0] if e.msg else str(e)
|
||||
if "net::" in msg:
|
||||
raise BrowsingError(
|
||||
"A networking error occurred while trying to load the page: %s"
|
||||
% re.sub(r"^unknown error: ", "", msg)
|
||||
)
|
||||
raise CommandExecutionError(msg)
|
||||
finally:
|
||||
if driver:
|
||||
driver.close()
|
||||
|
||||
def scrape_text_with_selenium(self, driver: WebDriver) -> str:
|
||||
"""Scrape text from a browser window using selenium
|
||||
|
||||
def _sideload_chrome_extensions(options: ChromeOptions, dl_folder: Path) -> None:
|
||||
crx_download_url_template = "https://clients2.google.com/service/update2/crx?response=redirect&prodversion=49.0&acceptformat=crx3&x=id%3D{crx_id}%26installsource%3Dondemand%26uc" # noqa
|
||||
cookiewall_squasher_crx_id = "edibdbjcniadpccecjdfdjjppcpchdlm"
|
||||
adblocker_crx_id = "cjpalhdlnbpafiamejdnhcphjbkeiagm"
|
||||
Args:
|
||||
driver (WebDriver): A driver object representing
|
||||
the browser window to scrape
|
||||
|
||||
# Make sure the target folder exists
|
||||
dl_folder.mkdir(parents=True, exist_ok=True)
|
||||
Returns:
|
||||
str: the text scraped from the website
|
||||
"""
|
||||
|
||||
for crx_id in (cookiewall_squasher_crx_id, adblocker_crx_id):
|
||||
crx_path = dl_folder / f"{crx_id}.crx"
|
||||
if not crx_path.exists():
|
||||
logger.debug(f"Downloading CRX {crx_id}...")
|
||||
crx_download_url = crx_download_url_template.format(crx_id=crx_id)
|
||||
urlretrieve(crx_download_url, crx_path)
|
||||
logger.debug(f"Downloaded {crx_path.name}")
|
||||
options.add_extension(str(crx_path))
|
||||
# Get the HTML content directly from the browser's DOM
|
||||
page_source = driver.execute_script("return document.body.outerHTML;")
|
||||
soup = BeautifulSoup(page_source, "html.parser")
|
||||
|
||||
for script in soup(["script", "style"]):
|
||||
script.extract()
|
||||
|
||||
def close_browser(driver: WebDriver) -> None:
|
||||
"""Close the browser
|
||||
text = soup.get_text()
|
||||
lines = (line.strip() for line in text.splitlines())
|
||||
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
|
||||
text = "\n".join(chunk for chunk in chunks if chunk)
|
||||
return text
|
||||
|
||||
Args:
|
||||
driver (WebDriver): The webdriver to close
|
||||
def scrape_links_with_selenium(self, driver: WebDriver, base_url: str) -> list[str]:
|
||||
"""Scrape links from a website using selenium
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
driver.quit()
|
||||
Args:
|
||||
driver (WebDriver): A driver object representing
|
||||
the browser window to scrape
|
||||
base_url (str): The base URL to use for resolving relative links
|
||||
|
||||
Returns:
|
||||
List[str]: The links scraped from the website
|
||||
"""
|
||||
page_source = driver.page_source
|
||||
soup = BeautifulSoup(page_source, "html.parser")
|
||||
|
||||
async def summarize_memorize_webpage(
|
||||
url: str,
|
||||
text: str,
|
||||
question: str | None,
|
||||
topics_of_interest: list[str],
|
||||
agent: Agent,
|
||||
driver: Optional[WebDriver] = None,
|
||||
) -> str:
|
||||
"""Summarize text using the OpenAI API
|
||||
for script in soup(["script", "style"]):
|
||||
script.extract()
|
||||
|
||||
Args:
|
||||
url (str): The url of the text
|
||||
text (str): The text to summarize
|
||||
question (str): The question to ask the model
|
||||
driver (WebDriver): The webdriver to use to scroll the page
|
||||
hyperlinks = extract_hyperlinks(soup, base_url)
|
||||
|
||||
Returns:
|
||||
str: The summary of the text
|
||||
"""
|
||||
if not text:
|
||||
raise ValueError("No text to summarize")
|
||||
return format_hyperlinks(hyperlinks)
|
||||
|
||||
text_length = len(text)
|
||||
logger.debug(f"Web page content length: {text_length} characters")
|
||||
async def open_page_in_browser(self, url: str, config: Config) -> WebDriver:
|
||||
"""Open a browser window and load a web page using Selenium
|
||||
|
||||
# memory = get_memory(agent.legacy_config)
|
||||
Params:
|
||||
url (str): The URL of the page to load
|
||||
config (Config): The applicable application configuration
|
||||
|
||||
# new_memory = MemoryItem.from_webpage(
|
||||
# content=text,
|
||||
# url=url,
|
||||
# config=agent.legacy_config,
|
||||
# question=question,
|
||||
# )
|
||||
# memory.add(new_memory)
|
||||
Returns:
|
||||
driver (WebDriver): A driver object representing
|
||||
the browser window to scrape
|
||||
"""
|
||||
logging.getLogger("selenium").setLevel(logging.CRITICAL)
|
||||
|
||||
result = None
|
||||
information = None
|
||||
if topics_of_interest:
|
||||
information = await extract_information(
|
||||
text,
|
||||
topics_of_interest=topics_of_interest,
|
||||
llm_provider=agent.llm_provider,
|
||||
config=agent.legacy_config,
|
||||
options_available: dict[str, Type[BrowserOptions]] = {
|
||||
"chrome": ChromeOptions,
|
||||
"edge": EdgeOptions,
|
||||
"firefox": FirefoxOptions,
|
||||
"safari": SafariOptions,
|
||||
}
|
||||
|
||||
options: BrowserOptions = options_available[config.selenium_web_browser]()
|
||||
options.add_argument(f"user-agent={config.user_agent}")
|
||||
|
||||
if isinstance(options, FirefoxOptions):
|
||||
if config.selenium_headless:
|
||||
options.headless = True
|
||||
options.add_argument("--disable-gpu")
|
||||
driver = FirefoxDriver(
|
||||
service=GeckoDriverService(GeckoDriverManager().install()),
|
||||
options=options,
|
||||
)
|
||||
elif isinstance(options, EdgeOptions):
|
||||
driver = EdgeDriver(
|
||||
service=EdgeDriverService(EdgeDriverManager().install()),
|
||||
options=options,
|
||||
)
|
||||
elif isinstance(options, SafariOptions):
|
||||
# Requires a bit more setup on the users end.
|
||||
# See https://developer.apple.com/documentation/webkit/testing_with_webdriver_in_safari # noqa: E501
|
||||
driver = SafariDriver(options=options)
|
||||
elif isinstance(options, ChromeOptions):
|
||||
if platform == "linux" or platform == "linux2":
|
||||
options.add_argument("--disable-dev-shm-usage")
|
||||
options.add_argument("--remote-debugging-port=9222")
|
||||
|
||||
options.add_argument("--no-sandbox")
|
||||
if config.selenium_headless:
|
||||
options.add_argument("--headless=new")
|
||||
options.add_argument("--disable-gpu")
|
||||
|
||||
self._sideload_chrome_extensions(
|
||||
options, config.app_data_dir / "assets" / "crx"
|
||||
)
|
||||
|
||||
if (chromium_driver_path := Path("/usr/bin/chromedriver")).exists():
|
||||
chrome_service = ChromeDriverService(str(chromium_driver_path))
|
||||
else:
|
||||
try:
|
||||
chrome_driver = ChromeDriverManager().install()
|
||||
except AttributeError as e:
|
||||
if "'NoneType' object has no attribute 'split'" in str(e):
|
||||
# https://github.com/SergeyPirogov/webdriver_manager/issues/649
|
||||
logger.critical(
|
||||
"Connecting to browser failed:"
|
||||
" is Chrome or Chromium installed?"
|
||||
)
|
||||
raise
|
||||
chrome_service = ChromeDriverService(chrome_driver)
|
||||
driver = ChromeDriver(service=chrome_service, options=options)
|
||||
|
||||
driver.get(url)
|
||||
|
||||
# Wait for page to be ready, sleep 2 seconds, wait again until page ready.
|
||||
# This allows the cookiewall squasher time to get rid of cookie walls.
|
||||
WebDriverWait(driver, 10).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
return "\n".join(f"* {i}" for i in information)
|
||||
else:
|
||||
result, _ = await summarize_text(
|
||||
text,
|
||||
question=question,
|
||||
llm_provider=agent.llm_provider,
|
||||
config=agent.legacy_config,
|
||||
await asyncio.sleep(2)
|
||||
WebDriverWait(driver, 10).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
return result
|
||||
|
||||
return driver
|
||||
|
||||
def _sideload_chrome_extensions(
|
||||
self, options: ChromeOptions, dl_folder: Path
|
||||
) -> None:
|
||||
crx_download_url_template = "https://clients2.google.com/service/update2/crx?response=redirect&prodversion=49.0&acceptformat=crx3&x=id%3D{crx_id}%26installsource%3Dondemand%26uc" # noqa
|
||||
cookiewall_squasher_crx_id = "edibdbjcniadpccecjdfdjjppcpchdlm"
|
||||
adblocker_crx_id = "cjpalhdlnbpafiamejdnhcphjbkeiagm"
|
||||
|
||||
# Make sure the target folder exists
|
||||
dl_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for crx_id in (cookiewall_squasher_crx_id, adblocker_crx_id):
|
||||
crx_path = dl_folder / f"{crx_id}.crx"
|
||||
if not crx_path.exists():
|
||||
logger.debug(f"Downloading CRX {crx_id}...")
|
||||
crx_download_url = crx_download_url_template.format(crx_id=crx_id)
|
||||
urlretrieve(crx_download_url, crx_path)
|
||||
logger.debug(f"Downloaded {crx_path.name}")
|
||||
options.add_extension(str(crx_path))
|
||||
|
||||
async def summarize_webpage(
|
||||
self,
|
||||
text: str,
|
||||
question: str | None,
|
||||
topics_of_interest: list[str],
|
||||
) -> str:
|
||||
"""Summarize text using the OpenAI API
|
||||
|
||||
Args:
|
||||
url (str): The url of the text
|
||||
text (str): The text to summarize
|
||||
question (str): The question to ask the model
|
||||
driver (WebDriver): The webdriver to use to scroll the page
|
||||
|
||||
Returns:
|
||||
str: The summary of the text
|
||||
"""
|
||||
if not text:
|
||||
raise ValueError("No text to summarize")
|
||||
|
||||
text_length = len(text)
|
||||
logger.debug(f"Web page content length: {text_length} characters")
|
||||
|
||||
result = None
|
||||
information = None
|
||||
if topics_of_interest:
|
||||
information = await extract_information(
|
||||
text,
|
||||
topics_of_interest=topics_of_interest,
|
||||
llm_provider=self.llm_provider,
|
||||
config=self.legacy_config,
|
||||
)
|
||||
return "\n".join(f"* {i}" for i in information)
|
||||
else:
|
||||
result, _ = await summarize_text(
|
||||
text,
|
||||
question=question,
|
||||
llm_provider=self.llm_provider,
|
||||
config=self.legacy_config,
|
||||
)
|
||||
return result
|
||||
|
||||
82
autogpts/autogpt/autogpt/components/event_history.py
Normal file
82
autogpts/autogpt/autogpt/components/event_history.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from typing import Callable, Generic, Iterator, Optional
|
||||
|
||||
from autogpt.agents.features.watchdog import WatchdogComponent
|
||||
from autogpt.agents.protocols import AfterExecute, AfterParse, MessageProvider
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.core.resource.model_providers.schema import ChatMessage, ChatModelProvider
|
||||
from autogpt.models.action_history import (
|
||||
AP,
|
||||
ActionResult,
|
||||
Episode,
|
||||
EpisodicActionHistory,
|
||||
)
|
||||
from autogpt.prompts.utils import indent
|
||||
|
||||
|
||||
class EventHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[AP]):
|
||||
"""Keeps track of the event history and provides a summary of the steps."""
|
||||
|
||||
run_after = [WatchdogComponent]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_history: EpisodicActionHistory[AP],
|
||||
max_tokens: int,
|
||||
count_tokens: Callable[[str], int],
|
||||
legacy_config: Config,
|
||||
llm_provider: ChatModelProvider,
|
||||
) -> None:
|
||||
self.event_history = event_history
|
||||
self.max_tokens = max_tokens
|
||||
self.count_tokens = count_tokens
|
||||
self.legacy_config = legacy_config
|
||||
self.llm_provider = llm_provider
|
||||
|
||||
def get_messages(self) -> Iterator[ChatMessage]:
|
||||
if progress := self._compile_progress(
|
||||
self.event_history.episodes,
|
||||
self.max_tokens,
|
||||
self.count_tokens,
|
||||
):
|
||||
yield ChatMessage.system(f"## Progress on your Task so far\n\n{progress}")
|
||||
|
||||
def after_parse(self, result: AP) -> None:
|
||||
self.event_history.register_action(result)
|
||||
|
||||
async def after_execute(self, result: ActionResult) -> None:
|
||||
self.event_history.register_result(result)
|
||||
await self.event_history.handle_compression(
|
||||
self.llm_provider, self.legacy_config
|
||||
)
|
||||
|
||||
def _compile_progress(
|
||||
self,
|
||||
episode_history: list[Episode],
|
||||
max_tokens: Optional[int] = None,
|
||||
count_tokens: Optional[Callable[[str], int]] = None,
|
||||
) -> str:
|
||||
if max_tokens and not count_tokens:
|
||||
raise ValueError("count_tokens is required if max_tokens is set")
|
||||
|
||||
steps: list[str] = []
|
||||
tokens: int = 0
|
||||
n_episodes = len(episode_history)
|
||||
|
||||
for i, episode in enumerate(reversed(episode_history)):
|
||||
# Use full format for the latest 4 steps, summary or format for older steps
|
||||
if i < 4 or episode.summary is None:
|
||||
step_content = indent(episode.format(), 2).strip()
|
||||
else:
|
||||
step_content = episode.summary
|
||||
|
||||
step = f"* Step {n_episodes - i}: {step_content}"
|
||||
|
||||
if max_tokens and count_tokens:
|
||||
step_tokens = count_tokens(step)
|
||||
if tokens + step_tokens > max_tokens:
|
||||
break
|
||||
tokens += step_tokens
|
||||
|
||||
steps.insert(0, step)
|
||||
|
||||
return "\n\n".join(steps)
|
||||
@@ -5,7 +5,7 @@ import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from autogpt.logs.helpers import request_user_double_check
|
||||
from autogpt.utils import validate_yaml_file
|
||||
from autogpt.utils.utils import validate_yaml_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -7,9 +7,8 @@ import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
from colorama import Fore
|
||||
from pydantic import Field, SecretStr, validator
|
||||
from pydantic import SecretStr, validator
|
||||
|
||||
import autogpt
|
||||
from autogpt.app.utils import clean_input
|
||||
@@ -18,13 +17,13 @@ from autogpt.core.configuration.schema import (
|
||||
SystemSettings,
|
||||
UserConfigurable,
|
||||
)
|
||||
from autogpt.core.resource.model_providers import CHAT_MODELS, ModelName
|
||||
from autogpt.core.resource.model_providers.openai import (
|
||||
OPEN_AI_CHAT_MODELS,
|
||||
OpenAICredentials,
|
||||
OpenAIModelName,
|
||||
)
|
||||
from autogpt.file_storage import FileStorageBackendName
|
||||
from autogpt.plugins.plugins_config import PluginsConfig
|
||||
from autogpt.logs.config import LoggingConfig
|
||||
from autogpt.speech import TTSConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -32,7 +31,6 @@ logger = logging.getLogger(__name__)
|
||||
PROJECT_ROOT = Path(autogpt.__file__).parent.parent
|
||||
AI_SETTINGS_FILE = Path("ai_settings.yaml")
|
||||
AZURE_CONFIG_FILE = Path("azure.yaml")
|
||||
PLUGINS_CONFIG_FILE = Path("plugins_config.yaml")
|
||||
PROMPT_SETTINGS_FILE = Path("prompt_settings.yaml")
|
||||
|
||||
GPT_4_MODEL = OpenAIModelName.GPT4
|
||||
@@ -53,11 +51,9 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
authorise_key: str = UserConfigurable(default="y", from_env="AUTHORISE_COMMAND_KEY")
|
||||
exit_key: str = UserConfigurable(default="n", from_env="EXIT_KEY")
|
||||
noninteractive_mode: bool = False
|
||||
chat_messages_enabled: bool = UserConfigurable(
|
||||
default=True, from_env=lambda: os.getenv("CHAT_MESSAGES_ENABLED") == "True"
|
||||
)
|
||||
|
||||
# TTS configuration
|
||||
logging: LoggingConfig = LoggingConfig()
|
||||
tts_config: TTSConfig = TTSConfig()
|
||||
|
||||
# File storage
|
||||
@@ -78,11 +74,11 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
)
|
||||
|
||||
# Model configuration
|
||||
fast_llm: OpenAIModelName = UserConfigurable(
|
||||
fast_llm: ModelName = UserConfigurable(
|
||||
default=OpenAIModelName.GPT3,
|
||||
from_env="FAST_LLM",
|
||||
)
|
||||
smart_llm: OpenAIModelName = UserConfigurable(
|
||||
smart_llm: ModelName = UserConfigurable(
|
||||
default=OpenAIModelName.GPT4_TURBO,
|
||||
from_env="SMART_LLM",
|
||||
)
|
||||
@@ -118,9 +114,9 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
# Commands #
|
||||
############
|
||||
# General
|
||||
disabled_command_categories: list[str] = UserConfigurable(
|
||||
disabled_commands: list[str] = UserConfigurable(
|
||||
default_factory=list,
|
||||
from_env=lambda: _safe_split(os.getenv("DISABLED_COMMAND_CATEGORIES")),
|
||||
from_env=lambda: _safe_split(os.getenv("DISABLED_COMMANDS")),
|
||||
)
|
||||
|
||||
# File ops
|
||||
@@ -179,29 +175,6 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
from_env="USER_AGENT",
|
||||
)
|
||||
|
||||
###################
|
||||
# Plugin Settings #
|
||||
###################
|
||||
plugins_dir: str = UserConfigurable("plugins", from_env="PLUGINS_DIR")
|
||||
plugins_config_file: Path = UserConfigurable(
|
||||
default=PLUGINS_CONFIG_FILE, from_env="PLUGINS_CONFIG_FILE"
|
||||
)
|
||||
plugins_config: PluginsConfig = Field(
|
||||
default_factory=lambda: PluginsConfig(plugins={})
|
||||
)
|
||||
plugins: list[AutoGPTPluginTemplate] = Field(default_factory=list, exclude=True)
|
||||
plugins_allowlist: list[str] = UserConfigurable(
|
||||
default_factory=list,
|
||||
from_env=lambda: _safe_split(os.getenv("ALLOWLISTED_PLUGINS")),
|
||||
)
|
||||
plugins_denylist: list[str] = UserConfigurable(
|
||||
default_factory=list,
|
||||
from_env=lambda: _safe_split(os.getenv("DENYLISTED_PLUGINS")),
|
||||
)
|
||||
plugins_openai: list[str] = UserConfigurable(
|
||||
default_factory=list, from_env=lambda: _safe_split(os.getenv("OPENAI_PLUGINS"))
|
||||
)
|
||||
|
||||
###############
|
||||
# Credentials #
|
||||
###############
|
||||
@@ -229,22 +202,12 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
# Stable Diffusion
|
||||
sd_webui_auth: Optional[str] = UserConfigurable(from_env="SD_WEBUI_AUTH")
|
||||
|
||||
@validator("plugins", each_item=True)
|
||||
def validate_plugins(cls, p: AutoGPTPluginTemplate | Any):
|
||||
assert issubclass(
|
||||
p.__class__, AutoGPTPluginTemplate
|
||||
), f"{p} does not subclass AutoGPTPluginTemplate"
|
||||
assert (
|
||||
p.__class__.__name__ != "AutoGPTPluginTemplate"
|
||||
), f"Plugins must subclass AutoGPTPluginTemplate; {p} is a template instance"
|
||||
return p
|
||||
|
||||
@validator("openai_functions")
|
||||
def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
|
||||
if v:
|
||||
smart_llm = values["smart_llm"]
|
||||
assert OPEN_AI_CHAT_MODELS[smart_llm].has_function_call_api, (
|
||||
f"Model {smart_llm} does not support OpenAI Functions. "
|
||||
assert CHAT_MODELS[smart_llm].has_function_call_api, (
|
||||
f"Model {smart_llm} does not support tool calling. "
|
||||
"Please disable OPENAI_FUNCTIONS or choose a suitable model."
|
||||
)
|
||||
return v
|
||||
@@ -264,7 +227,6 @@ class ConfigBuilder(Configurable[Config]):
|
||||
for k in {
|
||||
"ai_settings_file", # TODO: deprecate or repurpose
|
||||
"prompt_settings_file", # TODO: deprecate or repurpose
|
||||
"plugins_config_file", # TODO: move from project root
|
||||
"azure_config_file", # TODO: move from project root
|
||||
}:
|
||||
setattr(config, k, project_root / getattr(config, k))
|
||||
@@ -276,18 +238,12 @@ class ConfigBuilder(Configurable[Config]):
|
||||
):
|
||||
config.openai_credentials.load_azure_config(config_file)
|
||||
|
||||
config.plugins_config = PluginsConfig.load_config(
|
||||
config.plugins_config_file,
|
||||
config.plugins_denylist,
|
||||
config.plugins_allowlist,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def assert_config_has_openai_api_key(config: Config) -> None:
|
||||
"""Check if the OpenAI API key is set in config.py or as an environment variable."""
|
||||
key_pattern = r"^sk-\w{48}"
|
||||
key_pattern = r"^sk-(proj-)?\w{48}"
|
||||
openai_api_key = (
|
||||
config.openai_credentials.api_key.get_secret_value()
|
||||
if config.openai_credentials
|
||||
|
||||
@@ -24,6 +24,7 @@ class LanguageModelClassification(str, enum.Enum):
|
||||
class ChatPrompt(BaseModel):
|
||||
messages: list[ChatMessage]
|
||||
functions: list[CompletionModelFunction] = Field(default_factory=list)
|
||||
prefill_response: str = ""
|
||||
|
||||
def raw(self) -> list[ChatMessageDict]:
|
||||
return [m.dict() for m in self.messages]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .multi import CHAT_MODELS, ModelName, MultiProvider
|
||||
from .openai import (
|
||||
OPEN_AI_CHAT_MODELS,
|
||||
OPEN_AI_EMBEDDING_MODELS,
|
||||
@@ -42,11 +43,13 @@ __all__ = [
|
||||
"ChatModelProvider",
|
||||
"ChatModelResponse",
|
||||
"CompletionModelFunction",
|
||||
"CHAT_MODELS",
|
||||
"Embedding",
|
||||
"EmbeddingModelInfo",
|
||||
"EmbeddingModelProvider",
|
||||
"EmbeddingModelResponse",
|
||||
"ModelInfo",
|
||||
"ModelName",
|
||||
"ModelProvider",
|
||||
"ModelProviderBudget",
|
||||
"ModelProviderCredentials",
|
||||
@@ -56,6 +59,7 @@ __all__ = [
|
||||
"ModelProviderUsage",
|
||||
"ModelResponse",
|
||||
"ModelTokenizer",
|
||||
"MultiProvider",
|
||||
"OPEN_AI_MODELS",
|
||||
"OPEN_AI_CHAT_MODELS",
|
||||
"OPEN_AI_EMBEDDING_MODELS",
|
||||
|
||||
@@ -0,0 +1,495 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Callable, Optional, ParamSpec, TypeVar
|
||||
|
||||
import sentry_sdk
|
||||
import tenacity
|
||||
import tiktoken
|
||||
from anthropic import APIConnectionError, APIStatusError
|
||||
from pydantic import SecretStr
|
||||
|
||||
from autogpt.core.configuration import Configurable, UserConfigurable
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
AssistantChatMessage,
|
||||
AssistantFunctionCall,
|
||||
AssistantToolCall,
|
||||
ChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
CompletionModelFunction,
|
||||
ModelProviderBudget,
|
||||
ModelProviderConfiguration,
|
||||
ModelProviderCredentials,
|
||||
ModelProviderName,
|
||||
ModelProviderSettings,
|
||||
ModelTokenizer,
|
||||
ToolResultMessage,
|
||||
)
|
||||
|
||||
from .utils import validate_tool_calls
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from anthropic.types.beta.tools import MessageCreateParams
|
||||
from anthropic.types.beta.tools import ToolsBetaMessage as Message
|
||||
from anthropic.types.beta.tools import ToolsBetaMessageParam as MessageParam
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
class AnthropicModelName(str, enum.Enum):
|
||||
CLAUDE3_OPUS_v1 = "claude-3-opus-20240229"
|
||||
CLAUDE3_SONNET_v1 = "claude-3-sonnet-20240229"
|
||||
CLAUDE3_HAIKU_v1 = "claude-3-haiku-20240307"
|
||||
|
||||
|
||||
ANTHROPIC_CHAT_MODELS = {
|
||||
info.name: info
|
||||
for info in [
|
||||
ChatModelInfo(
|
||||
name=AnthropicModelName.CLAUDE3_OPUS_v1,
|
||||
provider_name=ModelProviderName.ANTHROPIC,
|
||||
prompt_token_cost=15 / 1e6,
|
||||
completion_token_cost=75 / 1e6,
|
||||
max_tokens=200000,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=AnthropicModelName.CLAUDE3_SONNET_v1,
|
||||
provider_name=ModelProviderName.ANTHROPIC,
|
||||
prompt_token_cost=3 / 1e6,
|
||||
completion_token_cost=15 / 1e6,
|
||||
max_tokens=200000,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=AnthropicModelName.CLAUDE3_HAIKU_v1,
|
||||
provider_name=ModelProviderName.ANTHROPIC,
|
||||
prompt_token_cost=0.25 / 1e6,
|
||||
completion_token_cost=1.25 / 1e6,
|
||||
max_tokens=200000,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class AnthropicConfiguration(ModelProviderConfiguration):
|
||||
fix_failed_parse_tries: int = UserConfigurable(3)
|
||||
|
||||
|
||||
class AnthropicCredentials(ModelProviderCredentials):
|
||||
"""Credentials for Anthropic."""
|
||||
|
||||
api_key: SecretStr = UserConfigurable(from_env="ANTHROPIC_API_KEY")
|
||||
api_base: Optional[SecretStr] = UserConfigurable(
|
||||
default=None, from_env="ANTHROPIC_API_BASE_URL"
|
||||
)
|
||||
|
||||
def get_api_access_kwargs(self) -> dict[str, str]:
|
||||
return {
|
||||
k: (v.get_secret_value() if type(v) is SecretStr else v)
|
||||
for k, v in {
|
||||
"api_key": self.api_key,
|
||||
"base_url": self.api_base,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
|
||||
class AnthropicSettings(ModelProviderSettings):
|
||||
configuration: AnthropicConfiguration
|
||||
credentials: Optional[AnthropicCredentials]
|
||||
budget: ModelProviderBudget
|
||||
|
||||
|
||||
class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
|
||||
default_settings = AnthropicSettings(
|
||||
name="anthropic_provider",
|
||||
description="Provides access to Anthropic's API.",
|
||||
configuration=AnthropicConfiguration(
|
||||
retries_per_request=7,
|
||||
),
|
||||
credentials=None,
|
||||
budget=ModelProviderBudget(),
|
||||
)
|
||||
|
||||
_settings: AnthropicSettings
|
||||
_configuration: AnthropicConfiguration
|
||||
_credentials: AnthropicCredentials
|
||||
_budget: ModelProviderBudget
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Optional[AnthropicSettings] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
):
|
||||
if not settings:
|
||||
settings = self.default_settings.copy(deep=True)
|
||||
if not settings.credentials:
|
||||
settings.credentials = AnthropicCredentials.from_env()
|
||||
|
||||
super(AnthropicProvider, self).__init__(settings=settings, logger=logger)
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
self._client = AsyncAnthropic(**self._credentials.get_api_access_kwargs())
|
||||
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
return list(ANTHROPIC_CHAT_MODELS.values())
|
||||
|
||||
def get_token_limit(self, model_name: str) -> int:
|
||||
"""Get the token limit for a given model."""
|
||||
return ANTHROPIC_CHAT_MODELS[model_name].max_tokens
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, model_name: AnthropicModelName) -> ModelTokenizer:
|
||||
# HACK: No official tokenizer is available for Claude 3
|
||||
return tiktoken.encoding_for_model(model_name)
|
||||
|
||||
@classmethod
|
||||
def count_tokens(cls, text: str, model_name: AnthropicModelName) -> int:
|
||||
return 0 # HACK: No official tokenizer is available for Claude 3
|
||||
|
||||
@classmethod
|
||||
def count_message_tokens(
|
||||
cls,
|
||||
messages: ChatMessage | list[ChatMessage],
|
||||
model_name: AnthropicModelName,
|
||||
) -> int:
|
||||
return 0 # HACK: No official tokenizer is available for Claude 3
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: AnthropicModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
**kwargs,
|
||||
) -> ChatModelResponse[_T]:
|
||||
"""Create a completion using the Anthropic API."""
|
||||
anthropic_messages, completion_kwargs = self._get_chat_completion_args(
|
||||
prompt_messages=model_prompt,
|
||||
model=model_name,
|
||||
functions=functions,
|
||||
max_output_tokens=max_output_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
total_cost = 0.0
|
||||
attempts = 0
|
||||
while True:
|
||||
completion_kwargs["messages"] = anthropic_messages.copy()
|
||||
if prefill_response:
|
||||
completion_kwargs["messages"].append(
|
||||
{"role": "assistant", "content": prefill_response}
|
||||
)
|
||||
|
||||
(
|
||||
_assistant_msg,
|
||||
cost,
|
||||
t_input,
|
||||
t_output,
|
||||
) = await self._create_chat_completion(completion_kwargs)
|
||||
total_cost += cost
|
||||
self._logger.debug(
|
||||
f"Completion usage: {t_input} input, {t_output} output "
|
||||
f"- ${round(cost, 5)}"
|
||||
)
|
||||
|
||||
# Merge prefill into generated response
|
||||
if prefill_response:
|
||||
first_text_block = next(
|
||||
b for b in _assistant_msg.content if b.type == "text"
|
||||
)
|
||||
first_text_block.text = prefill_response + first_text_block.text
|
||||
|
||||
assistant_msg = AssistantChatMessage(
|
||||
content="\n\n".join(
|
||||
b.text for b in _assistant_msg.content if b.type == "text"
|
||||
),
|
||||
tool_calls=self._parse_assistant_tool_calls(_assistant_msg),
|
||||
)
|
||||
|
||||
# If parsing the response fails, append the error to the prompt, and let the
|
||||
# LLM fix its mistake(s).
|
||||
attempts += 1
|
||||
tool_call_errors = []
|
||||
try:
|
||||
# Validate tool calls
|
||||
if assistant_msg.tool_calls and functions:
|
||||
tool_call_errors = validate_tool_calls(
|
||||
assistant_msg.tool_calls, functions
|
||||
)
|
||||
if tool_call_errors:
|
||||
raise ValueError(
|
||||
"Invalid tool use(s):\n"
|
||||
+ "\n".join(str(e) for e in tool_call_errors)
|
||||
)
|
||||
|
||||
parsed_result = completion_parser(assistant_msg)
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.debug(
|
||||
f"Parsing failed on response: '''{_assistant_msg}'''"
|
||||
)
|
||||
self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
|
||||
sentry_sdk.capture_exception(
|
||||
error=e,
|
||||
extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
|
||||
)
|
||||
if attempts < self._configuration.fix_failed_parse_tries:
|
||||
anthropic_messages.append(
|
||||
_assistant_msg.dict(include={"role", "content"})
|
||||
)
|
||||
anthropic_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*(
|
||||
# tool_result is required if last assistant message
|
||||
# had tool_use block(s)
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tc.id,
|
||||
"is_error": True,
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Not executed because parsing "
|
||||
"of your last message failed"
|
||||
if not tool_call_errors
|
||||
else str(e)
|
||||
if (
|
||||
e := next(
|
||||
(
|
||||
tce
|
||||
for tce in tool_call_errors
|
||||
if tce.name
|
||||
== tc.function.name
|
||||
),
|
||||
None,
|
||||
)
|
||||
)
|
||||
else "Not executed because validation "
|
||||
"of tool input failed",
|
||||
}
|
||||
],
|
||||
}
|
||||
for tc in assistant_msg.tool_calls or []
|
||||
),
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
"ERROR PARSING YOUR RESPONSE:\n\n"
|
||||
f"{e.__class__.__name__}: {e}"
|
||||
),
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
if attempts > 1:
|
||||
self._logger.debug(
|
||||
f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
|
||||
)
|
||||
|
||||
return ChatModelResponse(
|
||||
response=assistant_msg,
|
||||
parsed_result=parsed_result,
|
||||
model_info=ANTHROPIC_CHAT_MODELS[model_name],
|
||||
prompt_tokens_used=t_input,
|
||||
completion_tokens_used=t_output,
|
||||
)
|
||||
|
||||
def _get_chat_completion_args(
|
||||
self,
|
||||
prompt_messages: list[ChatMessage],
|
||||
model: AnthropicModelName,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> tuple[list[MessageParam], MessageCreateParams]:
|
||||
"""Prepare arguments for message completion API call.
|
||||
|
||||
Args:
|
||||
prompt_messages: List of ChatMessages.
|
||||
model: The model to use.
|
||||
functions: Optional list of functions available to the LLM.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
list[MessageParam]: Prompt messages for the Anthropic call
|
||||
dict[str, Any]: Any other kwargs for the Anthropic call
|
||||
"""
|
||||
kwargs["model"] = model
|
||||
|
||||
if functions:
|
||||
kwargs["tools"] = [
|
||||
{
|
||||
"name": f.name,
|
||||
"description": f.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
name: param.to_dict()
|
||||
for name, param in f.parameters.items()
|
||||
},
|
||||
"required": [
|
||||
name
|
||||
for name, param in f.parameters.items()
|
||||
if param.required
|
||||
],
|
||||
},
|
||||
}
|
||||
for f in functions
|
||||
]
|
||||
|
||||
kwargs["max_tokens"] = max_output_tokens or 4096
|
||||
|
||||
if extra_headers := self._configuration.extra_request_headers:
|
||||
kwargs["extra_headers"] = kwargs.get("extra_headers", {})
|
||||
kwargs["extra_headers"].update(extra_headers.copy())
|
||||
|
||||
system_messages = [
|
||||
m for m in prompt_messages if m.role == ChatMessage.Role.SYSTEM
|
||||
]
|
||||
if (_n := len(system_messages)) > 1:
|
||||
self._logger.warning(
|
||||
f"Prompt has {_n} system messages; Anthropic supports only 1. "
|
||||
"They will be merged, and removed from the rest of the prompt."
|
||||
)
|
||||
kwargs["system"] = "\n\n".join(sm.content for sm in system_messages)
|
||||
|
||||
messages: list[MessageParam] = []
|
||||
for message in prompt_messages:
|
||||
if message.role == ChatMessage.Role.SYSTEM:
|
||||
continue
|
||||
elif message.role == ChatMessage.Role.USER:
|
||||
# Merge subsequent user messages
|
||||
if messages and (prev_msg := messages[-1])["role"] == "user":
|
||||
if isinstance(prev_msg["content"], str):
|
||||
prev_msg["content"] += f"\n\n{message.content}"
|
||||
else:
|
||||
assert isinstance(prev_msg["content"], list)
|
||||
prev_msg["content"].append(
|
||||
{"type": "text", "text": message.content}
|
||||
)
|
||||
else:
|
||||
messages.append({"role": "user", "content": message.content})
|
||||
# TODO: add support for image blocks
|
||||
elif message.role == ChatMessage.Role.ASSISTANT:
|
||||
if isinstance(message, AssistantChatMessage) and message.tool_calls:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
*(
|
||||
[{"type": "text", "text": message.content}]
|
||||
if message.content
|
||||
else []
|
||||
),
|
||||
*(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tc.id,
|
||||
"name": tc.function.name,
|
||||
"input": tc.function.arguments,
|
||||
}
|
||||
for tc in message.tool_calls
|
||||
),
|
||||
],
|
||||
}
|
||||
)
|
||||
elif message.content:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": message.content,
|
||||
}
|
||||
)
|
||||
elif isinstance(message, ToolResultMessage):
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message.tool_call_id,
|
||||
"content": [{"type": "text", "text": message.content}],
|
||||
"is_error": message.is_error,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return messages, kwargs # type: ignore
|
||||
|
||||
async def _create_chat_completion(
|
||||
self, completion_kwargs: MessageCreateParams
|
||||
) -> tuple[Message, float, int, int]:
|
||||
"""
|
||||
Create a chat completion using the Anthropic API with retry handling.
|
||||
|
||||
Params:
|
||||
completion_kwargs: Keyword arguments for an Anthropic Messages API call
|
||||
|
||||
Returns:
|
||||
Message: The message completion object
|
||||
float: The cost ($) of this completion
|
||||
int: Number of input tokens used
|
||||
int: Number of output tokens used
|
||||
"""
|
||||
|
||||
@self._retry_api_request
|
||||
async def _create_chat_completion_with_retry(
|
||||
completion_kwargs: MessageCreateParams,
|
||||
) -> Message:
|
||||
return await self._client.beta.tools.messages.create(
|
||||
**completion_kwargs # type: ignore
|
||||
)
|
||||
|
||||
response = await _create_chat_completion_with_retry(completion_kwargs)
|
||||
|
||||
cost = self._budget.update_usage_and_cost(
|
||||
model_info=ANTHROPIC_CHAT_MODELS[completion_kwargs["model"]],
|
||||
input_tokens_used=response.usage.input_tokens,
|
||||
output_tokens_used=response.usage.output_tokens,
|
||||
)
|
||||
return response, cost, response.usage.input_tokens, response.usage.output_tokens
|
||||
|
||||
def _parse_assistant_tool_calls(
|
||||
self, assistant_message: Message
|
||||
) -> list[AssistantToolCall]:
|
||||
return [
|
||||
AssistantToolCall(
|
||||
id=c.id,
|
||||
type="function",
|
||||
function=AssistantFunctionCall(name=c.name, arguments=c.input),
|
||||
)
|
||||
for c in assistant_message.content
|
||||
if c.type == "tool_use"
|
||||
]
|
||||
|
||||
def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
return tenacity.retry(
|
||||
retry=(
|
||||
tenacity.retry_if_exception_type(APIConnectionError)
|
||||
| tenacity.retry_if_exception(
|
||||
lambda e: isinstance(e, APIStatusError) and e.status_code >= 500
|
||||
)
|
||||
),
|
||||
wait=tenacity.wait_exponential(),
|
||||
stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
|
||||
after=tenacity.after_log(self._logger, logging.DEBUG),
|
||||
)(func)
|
||||
|
||||
def __repr__(self):
|
||||
return "AnthropicProvider()"
|
||||
162
autogpts/autogpt/autogpt/core/resource/model_providers/multi.py
Normal file
162
autogpts/autogpt/autogpt/core/resource/model_providers/multi.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, Iterator, Optional, TypeVar
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from autogpt.core.configuration import Configurable
|
||||
|
||||
from .anthropic import ANTHROPIC_CHAT_MODELS, AnthropicModelName, AnthropicProvider
|
||||
from .openai import OPEN_AI_CHAT_MODELS, OpenAIModelName, OpenAIProvider
|
||||
from .schema import (
|
||||
AssistantChatMessage,
|
||||
ChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
CompletionModelFunction,
|
||||
ModelProviderBudget,
|
||||
ModelProviderConfiguration,
|
||||
ModelProviderName,
|
||||
ModelProviderSettings,
|
||||
ModelTokenizer,
|
||||
)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
ModelName = AnthropicModelName | OpenAIModelName
|
||||
|
||||
CHAT_MODELS = {**ANTHROPIC_CHAT_MODELS, **OPEN_AI_CHAT_MODELS}
|
||||
|
||||
|
||||
class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
|
||||
default_settings = ModelProviderSettings(
|
||||
name="multi_provider",
|
||||
description=(
|
||||
"Provides access to all of the available models, regardless of provider."
|
||||
),
|
||||
configuration=ModelProviderConfiguration(
|
||||
retries_per_request=7,
|
||||
),
|
||||
budget=ModelProviderBudget(),
|
||||
)
|
||||
|
||||
_budget: ModelProviderBudget
|
||||
|
||||
_provider_instances: dict[ModelProviderName, ChatModelProvider]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Optional[ModelProviderSettings] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
):
|
||||
super(MultiProvider, self).__init__(settings=settings, logger=logger)
|
||||
self._budget = self._settings.budget or ModelProviderBudget()
|
||||
|
||||
self._provider_instances = {}
|
||||
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
models = []
|
||||
for provider in self.get_available_providers():
|
||||
models.extend(await provider.get_available_models())
|
||||
return models
|
||||
|
||||
def get_token_limit(self, model_name: ModelName) -> int:
|
||||
"""Get the token limit for a given model."""
|
||||
return self.get_model_provider(model_name).get_token_limit(model_name)
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, model_name: ModelName) -> ModelTokenizer:
|
||||
return cls._get_model_provider_class(model_name).get_tokenizer(model_name)
|
||||
|
||||
@classmethod
|
||||
def count_tokens(cls, text: str, model_name: ModelName) -> int:
|
||||
return cls._get_model_provider_class(model_name).count_tokens(
|
||||
text=text, model_name=model_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def count_message_tokens(
|
||||
cls, messages: ChatMessage | list[ChatMessage], model_name: ModelName
|
||||
) -> int:
|
||||
return cls._get_model_provider_class(model_name).count_message_tokens(
|
||||
messages=messages, model_name=model_name
|
||||
)
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: ModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
**kwargs,
|
||||
) -> ChatModelResponse[_T]:
|
||||
"""Create a completion using the Anthropic API."""
|
||||
return await self.get_model_provider(model_name).create_chat_completion(
|
||||
model_prompt=model_prompt,
|
||||
model_name=model_name,
|
||||
completion_parser=completion_parser,
|
||||
functions=functions,
|
||||
max_output_tokens=max_output_tokens,
|
||||
prefill_response=prefill_response,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_model_provider(self, model: ModelName) -> ChatModelProvider:
|
||||
model_info = CHAT_MODELS[model]
|
||||
return self._get_provider(model_info.provider_name)
|
||||
|
||||
def get_available_providers(self) -> Iterator[ChatModelProvider]:
|
||||
for provider_name in ModelProviderName:
|
||||
try:
|
||||
yield self._get_provider(provider_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_provider(self, provider_name: ModelProviderName) -> ChatModelProvider:
|
||||
_provider = self._provider_instances.get(provider_name)
|
||||
if not _provider:
|
||||
Provider = self._get_provider_class(provider_name)
|
||||
settings = Provider.default_settings.copy(deep=True)
|
||||
settings.budget = self._budget
|
||||
settings.configuration.extra_request_headers.update(
|
||||
self._settings.configuration.extra_request_headers
|
||||
)
|
||||
if settings.credentials is None:
|
||||
try:
|
||||
Credentials = settings.__fields__["credentials"].type_
|
||||
settings.credentials = Credentials.from_env()
|
||||
except ValidationError as e:
|
||||
raise ValueError(
|
||||
f"{provider_name} is unavailable: can't load credentials"
|
||||
) from e
|
||||
|
||||
self._provider_instances[provider_name] = _provider = Provider(
|
||||
settings=settings, logger=self._logger
|
||||
)
|
||||
_provider._budget = self._budget # Object binding not preserved by Pydantic
|
||||
return _provider
|
||||
|
||||
@classmethod
|
||||
def _get_model_provider_class(
|
||||
cls, model_name: ModelName
|
||||
) -> type[AnthropicProvider | OpenAIProvider]:
|
||||
return cls._get_provider_class(CHAT_MODELS[model_name].provider_name)
|
||||
|
||||
@classmethod
|
||||
def _get_provider_class(
|
||||
cls, provider_name: ModelProviderName
|
||||
) -> type[AnthropicProvider | OpenAIProvider]:
|
||||
try:
|
||||
return {
|
||||
ModelProviderName.ANTHROPIC: AnthropicProvider,
|
||||
ModelProviderName.OPENAI: OpenAIProvider,
|
||||
}[provider_name]
|
||||
except KeyError:
|
||||
raise ValueError(f"{provider_name} is not a known provider") from None
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}()"
|
||||
@@ -42,6 +42,8 @@ from autogpt.core.resource.model_providers.schema import (
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.core.utils.json_utils import json_loads
|
||||
|
||||
from .utils import validate_tool_calls
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
@@ -70,9 +72,11 @@ class OpenAIModelName(str, enum.Enum):
|
||||
GPT4_v3 = "gpt-4-1106-preview"
|
||||
GPT4_v3_VISION = "gpt-4-1106-vision-preview"
|
||||
GPT4_v4 = "gpt-4-0125-preview"
|
||||
GPT4_v5 = "gpt-4-turbo-2024-04-09"
|
||||
GPT4_ROLLING = "gpt-4"
|
||||
GPT4_ROLLING_32k = "gpt-4-32k"
|
||||
GPT4_TURBO = "gpt-4-turbo-preview"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT4_TURBO_PREVIEW = "gpt-4-turbo-preview"
|
||||
GPT4_VISION = "gpt-4-vision-preview"
|
||||
GPT4 = GPT4_ROLLING
|
||||
GPT4_32k = GPT4_ROLLING_32k
|
||||
@@ -180,8 +184,10 @@ chat_model_mapping = {
|
||||
OpenAIModelName.GPT4_TURBO: [
|
||||
OpenAIModelName.GPT4_v3,
|
||||
OpenAIModelName.GPT4_v3_VISION,
|
||||
OpenAIModelName.GPT4_v4,
|
||||
OpenAIModelName.GPT4_VISION,
|
||||
OpenAIModelName.GPT4_v4,
|
||||
OpenAIModelName.GPT4_TURBO_PREVIEW,
|
||||
OpenAIModelName.GPT4_v5,
|
||||
],
|
||||
}
|
||||
for base, copies in chat_model_mapping.items():
|
||||
@@ -294,6 +300,7 @@ class OpenAIProvider(
|
||||
budget=ModelProviderBudget(),
|
||||
)
|
||||
|
||||
_settings: OpenAISettings
|
||||
_configuration: OpenAIConfiguration
|
||||
_credentials: OpenAICredentials
|
||||
_budget: ModelProviderBudget
|
||||
@@ -308,11 +315,7 @@ class OpenAIProvider(
|
||||
if not settings.credentials:
|
||||
settings.credentials = OpenAICredentials.from_env()
|
||||
|
||||
self._settings = settings
|
||||
|
||||
self._configuration = settings.configuration
|
||||
self._credentials = settings.credentials
|
||||
self._budget = settings.budget
|
||||
super(OpenAIProvider, self).__init__(settings=settings, logger=logger)
|
||||
|
||||
if self._credentials.api_type == "azure":
|
||||
from openai import AsyncAzureOpenAI
|
||||
@@ -325,8 +328,6 @@ class OpenAIProvider(
|
||||
|
||||
self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs())
|
||||
|
||||
self._logger = logger or logging.getLogger(__name__)
|
||||
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
_models = (await self._client.models.list()).data
|
||||
return [OPEN_AI_MODELS[m.id] for m in _models if m.id in OPEN_AI_MODELS]
|
||||
@@ -394,9 +395,10 @@ class OpenAIProvider(
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "", # not supported by OpenAI
|
||||
**kwargs,
|
||||
) -> ChatModelResponse[_T]:
|
||||
"""Create a completion using the OpenAI API."""
|
||||
"""Create a completion using the OpenAI API and parse it."""
|
||||
|
||||
openai_messages, completion_kwargs = self._get_chat_completion_args(
|
||||
model_prompt=model_prompt,
|
||||
@@ -428,6 +430,10 @@ class OpenAIProvider(
|
||||
)
|
||||
parse_errors += _errors
|
||||
|
||||
# Validate tool calls
|
||||
if not parse_errors and tool_calls and functions:
|
||||
parse_errors += validate_tool_calls(tool_calls, functions)
|
||||
|
||||
assistant_msg = AssistantChatMessage(
|
||||
content=_assistant_msg.content,
|
||||
tool_calls=tool_calls or None,
|
||||
@@ -461,8 +467,11 @@ class OpenAIProvider(
|
||||
self._logger.debug(
|
||||
f"Parsing failed on response: '''{_assistant_msg}'''"
|
||||
)
|
||||
parse_errors_fmt = "\n\n".join(
|
||||
f"{e.__class__.__name__}: {e}" for e in parse_errors
|
||||
)
|
||||
self._logger.warning(
|
||||
f"Parsing attempt #{attempts} failed: {parse_errors}"
|
||||
f"Parsing attempt #{attempts} failed: {parse_errors_fmt}"
|
||||
)
|
||||
for e in parse_errors:
|
||||
sentry_sdk.capture_exception(
|
||||
@@ -476,10 +485,7 @@ class OpenAIProvider(
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"ERROR PARSING YOUR RESPONSE:\n\n"
|
||||
+ "\n\n".join(
|
||||
f"{e.__class__.__name__}: {e}" for e in parse_errors
|
||||
)
|
||||
f"ERROR PARSING YOUR RESPONSE:\n\n{parse_errors_fmt}"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import abc
|
||||
import enum
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
@@ -26,6 +28,10 @@ from autogpt.core.resource.schema import (
|
||||
ResourceType,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.logs.utils import fmt_kwargs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from jsonschema import ValidationError
|
||||
|
||||
|
||||
class ModelProviderService(str, enum.Enum):
|
||||
@@ -38,6 +44,7 @@ class ModelProviderService(str, enum.Enum):
|
||||
|
||||
class ModelProviderName(str, enum.Enum):
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
@@ -72,6 +79,9 @@ class AssistantFunctionCall(BaseModel):
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}({fmt_kwargs(self.arguments)})"
|
||||
|
||||
|
||||
class AssistantFunctionCallDict(TypedDict):
|
||||
name: str
|
||||
@@ -96,6 +106,12 @@ class AssistantChatMessage(ChatMessage):
|
||||
tool_calls: Optional[list[AssistantToolCall]] = None
|
||||
|
||||
|
||||
class ToolResultMessage(ChatMessage):
|
||||
role: Literal[ChatMessage.Role.TOOL] = ChatMessage.Role.TOOL
|
||||
is_error: bool = False
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class AssistantChatMessageDict(TypedDict, total=False):
|
||||
role: str
|
||||
content: str
|
||||
@@ -142,6 +158,30 @@ class CompletionModelFunction(BaseModel):
|
||||
)
|
||||
return f"{self.name}: {self.description}. Params: ({params})"
|
||||
|
||||
def validate_call(
|
||||
self, function_call: AssistantFunctionCall
|
||||
) -> tuple[bool, list["ValidationError"]]:
|
||||
"""
|
||||
Validates the given function call against the function's parameter specs
|
||||
|
||||
Returns:
|
||||
bool: Whether the given set of arguments is valid for this command
|
||||
list[ValidationError]: Issues with the set of arguments (if any)
|
||||
|
||||
Raises:
|
||||
ValueError: If the function_call doesn't call this function
|
||||
"""
|
||||
if function_call.name != self.name:
|
||||
raise ValueError(
|
||||
f"Can't validate {function_call.name} call using {self.name} spec"
|
||||
)
|
||||
|
||||
params_schema = JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
properties={name: spec for name, spec in self.parameters.items()},
|
||||
)
|
||||
return params_schema.validate_object(function_call.arguments)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""Struct for model information.
|
||||
@@ -225,7 +265,7 @@ class ModelProviderBudget(ProviderBudget):
|
||||
class ModelProviderSettings(ProviderSettings):
|
||||
resource_type: ResourceType = ResourceType.MODEL
|
||||
configuration: ModelProviderConfiguration
|
||||
credentials: ModelProviderCredentials
|
||||
credentials: Optional[ModelProviderCredentials] = None
|
||||
budget: Optional[ModelProviderBudget] = None
|
||||
|
||||
|
||||
@@ -234,9 +274,28 @@ class ModelProvider(abc.ABC):
|
||||
|
||||
default_settings: ClassVar[ModelProviderSettings]
|
||||
|
||||
_settings: ModelProviderSettings
|
||||
_configuration: ModelProviderConfiguration
|
||||
_credentials: Optional[ModelProviderCredentials] = None
|
||||
_budget: Optional[ModelProviderBudget] = None
|
||||
|
||||
_logger: logging.Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Optional[ModelProviderSettings] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
):
|
||||
if not settings:
|
||||
settings = self.default_settings.copy(deep=True)
|
||||
|
||||
self._settings = settings
|
||||
self._configuration = settings.configuration
|
||||
self._credentials = settings.credentials
|
||||
self._budget = settings.budget
|
||||
|
||||
self._logger = logger or logging.getLogger(self.__module__)
|
||||
|
||||
@abc.abstractmethod
|
||||
def count_tokens(self, text: str, model_name: str) -> int:
|
||||
...
|
||||
@@ -354,6 +413,7 @@ class ChatModelProvider(ModelProvider):
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
**kwargs,
|
||||
) -> ChatModelResponse[_T]:
|
||||
...
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
from typing import Any
|
||||
|
||||
from .schema import AssistantToolCall, CompletionModelFunction
|
||||
|
||||
|
||||
class InvalidFunctionCallError(Exception):
|
||||
def __init__(self, name: str, arguments: dict[str, Any], message: str):
|
||||
self.message = message
|
||||
self.name = name
|
||||
self.arguments = arguments
|
||||
super().__init__(message)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Invalid function call for {self.name}: {self.message}"
|
||||
|
||||
|
||||
def validate_tool_calls(
|
||||
tool_calls: list[AssistantToolCall], functions: list[CompletionModelFunction]
|
||||
) -> list[InvalidFunctionCallError]:
|
||||
"""
|
||||
Validates a list of tool calls against a list of functions.
|
||||
|
||||
1. Tries to find a function matching each tool call
|
||||
2. If a matching function is found, validates the tool call's arguments,
|
||||
reporting any resulting errors
|
||||
2. If no matching function is found, an error "Unknown function X" is reported
|
||||
3. A list of all errors encountered during validation is returned
|
||||
|
||||
Params:
|
||||
tool_calls: A list of tool calls to validate.
|
||||
functions: A list of functions to validate against.
|
||||
|
||||
Returns:
|
||||
list[InvalidFunctionCallError]: All errors encountered during validation.
|
||||
"""
|
||||
errors: list[InvalidFunctionCallError] = []
|
||||
for tool_call in tool_calls:
|
||||
function_call = tool_call.function
|
||||
|
||||
if function := next(
|
||||
(f for f in functions if f.name == function_call.name),
|
||||
None,
|
||||
):
|
||||
is_valid, validation_errors = function.validate_call(function_call)
|
||||
if not is_valid:
|
||||
fmt_errors = [
|
||||
f"{'.'.join(str(p) for p in f.path)}: {f.message}"
|
||||
if f.path
|
||||
else f.message
|
||||
for f in validation_errors
|
||||
]
|
||||
errors.append(
|
||||
InvalidFunctionCallError(
|
||||
name=function_call.name,
|
||||
arguments=function_call.arguments,
|
||||
message=(
|
||||
"The set of arguments supplied is invalid:\n"
|
||||
+ "\n".join(fmt_errors)
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
errors.append(
|
||||
InvalidFunctionCallError(
|
||||
name=function_call.name,
|
||||
arguments=function_call.arguments,
|
||||
message=f"Unknown function {function_call.name}",
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
@@ -4,10 +4,8 @@ from agent_protocol import StepHandler, StepResult
|
||||
|
||||
from autogpt.agents import Agent
|
||||
from autogpt.app.main import UserFeedback
|
||||
from autogpt.commands import COMMAND_CATEGORIES
|
||||
from autogpt.config import AIProfile, ConfigBuilder
|
||||
from autogpt.logs.helpers import user_friendly_output
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
from autogpt.prompts.prompt import DEFAULT_TRIGGERING_PROMPT
|
||||
|
||||
|
||||
@@ -82,7 +80,6 @@ def bootstrap_agent(task, continuous_mode) -> Agent:
|
||||
config.logging.plain_console_output = True
|
||||
config.continuous_mode = continuous_mode
|
||||
config.temperature = 0
|
||||
command_registry = CommandRegistry.with_command_modules(COMMAND_CATEGORIES, config)
|
||||
config.memory_backend = "no_memory"
|
||||
ai_profile = AIProfile(
|
||||
ai_name="AutoGPT",
|
||||
@@ -92,7 +89,6 @@ def bootstrap_agent(task, continuous_mode) -> Agent:
|
||||
# FIXME this won't work - ai_profile and triggering_prompt is not a valid argument,
|
||||
# lacks file_storage, settings and llm_provider
|
||||
return Agent(
|
||||
command_registry=command_registry,
|
||||
ai_profile=ai_profile,
|
||||
legacy_config=config,
|
||||
triggering_prompt=DEFAULT_TRIGGERING_PROMPT,
|
||||
|
||||
@@ -3,21 +3,25 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.core.prompting import ChatPrompt
|
||||
from autogpt.core.resource.model_providers import ChatMessage
|
||||
|
||||
SEPARATOR_LENGTH = 42
|
||||
|
||||
|
||||
def dump_prompt(prompt: "ChatPrompt") -> str:
|
||||
def dump_prompt(prompt: "ChatPrompt | list[ChatMessage]") -> str:
|
||||
def separator(text: str):
|
||||
half_sep_len = (SEPARATOR_LENGTH - 2 - len(text)) / 2
|
||||
return f"{floor(half_sep_len)*'-'} {text.upper()} {ceil(half_sep_len)*'-'}"
|
||||
|
||||
if not isinstance(prompt, list):
|
||||
prompt = prompt.messages
|
||||
|
||||
formatted_messages = "\n".join(
|
||||
[f"{separator(m.role)}\n{m.content}" for m in prompt.messages]
|
||||
[f"{separator(m.role)}\n{m.content}" for m in prompt]
|
||||
)
|
||||
return f"""
|
||||
============== {prompt.__class__.__name__} ==============
|
||||
Length: {len(prompt.messages)} messages
|
||||
Length: {len(prompt)} messages
|
||||
{formatted_messages}
|
||||
==========================================
|
||||
"""
|
||||
|
||||
@@ -57,10 +57,35 @@ class JSONSchema(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def from_dict(schema: dict) -> "JSONSchema":
|
||||
def resolve_references(schema: dict, definitions: dict) -> dict:
|
||||
"""
|
||||
Recursively resolve type $refs in the JSON schema with their definitions.
|
||||
"""
|
||||
if isinstance(schema, dict):
|
||||
if "$ref" in schema:
|
||||
ref_path = schema["$ref"].split("/")[
|
||||
2:
|
||||
] # Split and remove '#/definitions'
|
||||
ref_value = definitions
|
||||
for key in ref_path:
|
||||
ref_value = ref_value[key]
|
||||
return resolve_references(ref_value, definitions)
|
||||
else:
|
||||
return {
|
||||
k: resolve_references(v, definitions) for k, v in schema.items()
|
||||
}
|
||||
elif isinstance(schema, list):
|
||||
return [resolve_references(item, definitions) for item in schema]
|
||||
else:
|
||||
return schema
|
||||
|
||||
definitions = schema.get("definitions", {})
|
||||
schema = resolve_references(schema, definitions)
|
||||
|
||||
return JSONSchema(
|
||||
description=schema.get("description"),
|
||||
type=schema["type"],
|
||||
enum=schema["enum"] if "enum" in schema else None,
|
||||
enum=schema.get("enum"),
|
||||
items=JSONSchema.from_dict(schema["items"]) if "items" in schema else None,
|
||||
properties=JSONSchema.parse_properties(schema)
|
||||
if schema["type"] == "object"
|
||||
|
||||
@@ -39,7 +39,7 @@ def json_loads(json_str: str) -> Any:
|
||||
"JSON parse errors:\n" + "\n".join(str(e) for e in json_result.errors)
|
||||
)
|
||||
|
||||
if json_result.object is demjson3.undefined:
|
||||
if json_result.object in (demjson3.syntax_error, demjson3.undefined):
|
||||
raise ValueError(
|
||||
f"Failed to parse JSON string: {json_str}", *json_result.errors
|
||||
)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, Iterable, TypeVar
|
||||
from typing import TYPE_CHECKING, Callable, Iterable, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.models.command import Command
|
||||
|
||||
from autogpt.core.resource.model_providers import CompletionModelFunction
|
||||
from autogpt.models.command import Command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -12,7 +14,7 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T", bound=Callable)
|
||||
|
||||
|
||||
def get_openai_command_specs(
|
||||
def function_specs_from_commands(
|
||||
commands: Iterable[Command],
|
||||
) -> list[CompletionModelFunction]:
|
||||
"""Get OpenAI-consumable function specs for the agent's available commands.
|
||||
@@ -20,7 +22,7 @@ def get_openai_command_specs(
|
||||
"""
|
||||
return [
|
||||
CompletionModelFunction(
|
||||
name=command.name,
|
||||
name=command.names[0],
|
||||
description=command.description,
|
||||
parameters={param.name: param.spec for param in command.parameters},
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .config import configure_chat_plugins, configure_logging
|
||||
from .config import configure_logging
|
||||
from .helpers import user_friendly_output
|
||||
from .log_cycle import (
|
||||
CURRENT_CONTEXT_FILE_NAME,
|
||||
@@ -13,7 +13,6 @@ from .log_cycle import (
|
||||
|
||||
__all__ = [
|
||||
"configure_logging",
|
||||
"configure_chat_plugins",
|
||||
"user_friendly_output",
|
||||
"CURRENT_CONTEXT_FILE_NAME",
|
||||
"NEXT_ACTION_FILE_NAME",
|
||||
|
||||
@@ -8,11 +8,9 @@ import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
from openai._base_client import log as openai_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
from autogpt.speech import TTSConfig
|
||||
|
||||
from autogpt.core.configuration import SystemConfiguration, UserConfigurable
|
||||
@@ -34,8 +32,6 @@ DEBUG_LOG_FORMAT = (
|
||||
SPEECH_OUTPUT_LOGGER = "VOICE"
|
||||
USER_FRIENDLY_OUTPUT_LOGGER = "USER_FRIENDLY_OUTPUT"
|
||||
|
||||
_chat_plugins: list[AutoGPTPluginTemplate] = []
|
||||
|
||||
|
||||
class LogFormatName(str, enum.Enum):
|
||||
SIMPLE = "simple"
|
||||
@@ -81,12 +77,14 @@ def configure_logging(
|
||||
log_format: Optional[LogFormatName | str] = None,
|
||||
log_file_format: Optional[LogFormatName | str] = None,
|
||||
plain_console_output: Optional[bool] = None,
|
||||
config: Optional[LoggingConfig] = None,
|
||||
tts_config: Optional[TTSConfig] = None,
|
||||
) -> None:
|
||||
"""Configure the native logging module, based on the environment config and any
|
||||
specified overrides.
|
||||
|
||||
Arguments override values specified in the environment.
|
||||
Overrides are also applied to `config`, if passed.
|
||||
|
||||
Should be usable as `configure_logging(**config.logging.dict())`, where
|
||||
`config.logging` is a `LoggingConfig` object.
|
||||
@@ -111,14 +109,16 @@ def configure_logging(
|
||||
elif not isinstance(log_file_format, LogFormatName):
|
||||
raise ValueError(f"Unknown log format '{log_format}'")
|
||||
|
||||
config = LoggingConfig.from_env()
|
||||
config = config or LoggingConfig.from_env()
|
||||
|
||||
# Aggregate arguments + env config
|
||||
level = logging.DEBUG if debug else level or config.level
|
||||
log_dir = log_dir or config.log_dir
|
||||
log_format = log_format or (LogFormatName.DEBUG if debug else config.log_format)
|
||||
log_file_format = log_file_format or log_format or config.log_file_format
|
||||
plain_console_output = (
|
||||
# Aggregate env config + arguments
|
||||
config.level = logging.DEBUG if debug else level or config.level
|
||||
config.log_dir = log_dir or config.log_dir
|
||||
config.log_format = log_format or (
|
||||
LogFormatName.DEBUG if debug else config.log_format
|
||||
)
|
||||
config.log_file_format = log_file_format or log_format or config.log_file_format
|
||||
config.plain_console_output = (
|
||||
plain_console_output
|
||||
if plain_console_output is not None
|
||||
else config.plain_console_output
|
||||
@@ -126,18 +126,18 @@ def configure_logging(
|
||||
|
||||
# Structured logging is used for cloud environments,
|
||||
# where logging to a file makes no sense.
|
||||
if log_format == LogFormatName.STRUCTURED:
|
||||
plain_console_output = True
|
||||
log_file_format = None
|
||||
if config.log_format == LogFormatName.STRUCTURED:
|
||||
config.plain_console_output = True
|
||||
config.log_file_format = None
|
||||
|
||||
# create log directory if it doesn't exist
|
||||
if not log_dir.exists():
|
||||
log_dir.mkdir()
|
||||
if not config.log_dir.exists():
|
||||
config.log_dir.mkdir()
|
||||
|
||||
log_handlers: list[logging.Handler] = []
|
||||
|
||||
if log_format in (LogFormatName.DEBUG, LogFormatName.SIMPLE):
|
||||
console_format_template = TEXT_LOG_FORMAT_MAP[log_format]
|
||||
if config.log_format in (LogFormatName.DEBUG, LogFormatName.SIMPLE):
|
||||
console_format_template = TEXT_LOG_FORMAT_MAP[config.log_format]
|
||||
console_formatter = AutoGptFormatter(console_format_template)
|
||||
else:
|
||||
console_formatter = StructuredLoggingFormatter()
|
||||
@@ -145,7 +145,7 @@ def configure_logging(
|
||||
|
||||
# Console output handlers
|
||||
stdout = logging.StreamHandler(stream=sys.stdout)
|
||||
stdout.setLevel(level)
|
||||
stdout.setLevel(config.level)
|
||||
stdout.addFilter(BelowLevelFilter(logging.WARNING))
|
||||
stdout.setFormatter(console_formatter)
|
||||
stderr = logging.StreamHandler()
|
||||
@@ -162,7 +162,7 @@ def configure_logging(
|
||||
user_friendly_output_logger = logging.getLogger(USER_FRIENDLY_OUTPUT_LOGGER)
|
||||
user_friendly_output_logger.setLevel(logging.INFO)
|
||||
user_friendly_output_logger.addHandler(
|
||||
typing_console_handler if not plain_console_output else stdout
|
||||
typing_console_handler if not config.plain_console_output else stdout
|
||||
)
|
||||
if tts_config:
|
||||
user_friendly_output_logger.addHandler(TTSHandler(tts_config))
|
||||
@@ -170,22 +170,26 @@ def configure_logging(
|
||||
user_friendly_output_logger.propagate = False
|
||||
|
||||
# File output handlers
|
||||
if log_file_format is not None:
|
||||
if level < logging.ERROR:
|
||||
file_output_format_template = TEXT_LOG_FORMAT_MAP[log_file_format]
|
||||
if config.log_file_format is not None:
|
||||
if config.level < logging.ERROR:
|
||||
file_output_format_template = TEXT_LOG_FORMAT_MAP[config.log_file_format]
|
||||
file_output_formatter = AutoGptFormatter(
|
||||
file_output_format_template, no_color=True
|
||||
)
|
||||
|
||||
# INFO log file handler
|
||||
activity_log_handler = logging.FileHandler(log_dir / LOG_FILE, "a", "utf-8")
|
||||
activity_log_handler.setLevel(level)
|
||||
activity_log_handler = logging.FileHandler(
|
||||
config.log_dir / LOG_FILE, "a", "utf-8"
|
||||
)
|
||||
activity_log_handler.setLevel(config.level)
|
||||
activity_log_handler.setFormatter(file_output_formatter)
|
||||
log_handlers += [activity_log_handler]
|
||||
user_friendly_output_logger.addHandler(activity_log_handler)
|
||||
|
||||
# ERROR log file handler
|
||||
error_log_handler = logging.FileHandler(log_dir / ERROR_LOG_FILE, "a", "utf-8")
|
||||
error_log_handler = logging.FileHandler(
|
||||
config.log_dir / ERROR_LOG_FILE, "a", "utf-8"
|
||||
)
|
||||
error_log_handler.setLevel(logging.ERROR)
|
||||
error_log_handler.setFormatter(
|
||||
AutoGptFormatter(DEBUG_LOG_FORMAT, no_color=True)
|
||||
@@ -196,7 +200,7 @@ def configure_logging(
|
||||
# Configure the root logger
|
||||
logging.basicConfig(
|
||||
format=console_format_template,
|
||||
level=level,
|
||||
level=config.level,
|
||||
handlers=log_handlers,
|
||||
)
|
||||
|
||||
@@ -214,19 +218,3 @@ def configure_logging(
|
||||
|
||||
# Disable debug logging from OpenAI library
|
||||
openai_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def configure_chat_plugins(config: Config) -> None:
|
||||
"""Configure chat plugins for use by the logging module"""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add chat plugins capable of report to logger
|
||||
if config.chat_messages_enabled:
|
||||
if _chat_plugins:
|
||||
_chat_plugins.clear()
|
||||
|
||||
for plugin in config.plugins:
|
||||
if hasattr(plugin, "can_handle_report") and plugin.can_handle_report():
|
||||
logger.debug(f"Loaded plugin into logger: {plugin.__class__.__name__}")
|
||||
_chat_plugins.append(plugin)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Optional
|
||||
|
||||
from colorama import Fore
|
||||
|
||||
from .config import SPEECH_OUTPUT_LOGGER, USER_FRIENDLY_OUTPUT_LOGGER, _chat_plugins
|
||||
from .config import SPEECH_OUTPUT_LOGGER, USER_FRIENDLY_OUTPUT_LOGGER
|
||||
|
||||
|
||||
def user_friendly_output(
|
||||
@@ -21,10 +21,6 @@ def user_friendly_output(
|
||||
"""
|
||||
logger = logging.getLogger(USER_FRIENDLY_OUTPUT_LOGGER)
|
||||
|
||||
if _chat_plugins:
|
||||
for plugin in _chat_plugins:
|
||||
plugin.report(f"{title}: {message}")
|
||||
|
||||
logger.log(
|
||||
level,
|
||||
message,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
from typing import Any, Sequence, overload
|
||||
|
||||
import numpy as np
|
||||
@@ -51,16 +50,9 @@ async def get_embedding(
|
||||
|
||||
if isinstance(input, str):
|
||||
input = input.replace("\n", " ")
|
||||
|
||||
with suppress(NotImplementedError):
|
||||
return _get_embedding_with_plugin(input, config)
|
||||
|
||||
elif multiple and isinstance(input[0], str):
|
||||
input = [text.replace("\n", " ") for text in input]
|
||||
|
||||
with suppress(NotImplementedError):
|
||||
return [_get_embedding_with_plugin(i, config) for i in input]
|
||||
|
||||
model = config.embedding_model
|
||||
|
||||
logger.debug(
|
||||
@@ -86,13 +78,3 @@ async def get_embedding(
|
||||
)
|
||||
embeddings.append(result.embedding)
|
||||
return embeddings
|
||||
|
||||
|
||||
def _get_embedding_with_plugin(text: str, config: Config) -> Embedding:
|
||||
for plugin in config.plugins:
|
||||
if plugin.can_handle_text_embedding(text):
|
||||
embedding = plugin.handle_text_embedding(text)
|
||||
if embedding is not None:
|
||||
return embedding
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,31 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Iterator, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Any, Generic, Iterator, Literal, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
from autogpt.agents.base import BaseAgentActionProposal
|
||||
from autogpt.models.utils import ModelWithSummary
|
||||
from autogpt.processing.text import summarize_text
|
||||
from autogpt.prompts.utils import format_numbered_list, indent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.base import CommandArgs, CommandName
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.core.resource.model_providers import ChatModelProvider
|
||||
|
||||
|
||||
class Action(BaseModel):
|
||||
name: str
|
||||
args: dict[str, Any]
|
||||
reasoning: str
|
||||
|
||||
def format_call(self) -> str:
|
||||
return (
|
||||
f"{self.name}"
|
||||
f"({', '.join([f'{a}={repr(v)}' for a, v in self.args.items()])})"
|
||||
)
|
||||
|
||||
|
||||
class ActionSuccessResult(BaseModel):
|
||||
outputs: Any
|
||||
status: Literal["success"] = "success"
|
||||
@@ -87,15 +77,22 @@ class ActionInterruptedByHuman(BaseModel):
|
||||
|
||||
ActionResult = ActionSuccessResult | ActionErrorResult | ActionInterruptedByHuman
|
||||
|
||||
AP = TypeVar("AP", bound=BaseAgentActionProposal)
|
||||
|
||||
class Episode(BaseModel):
|
||||
action: Action
|
||||
|
||||
class Episode(GenericModel, Generic[AP]):
|
||||
action: AP
|
||||
result: ActionResult | None
|
||||
summary: str | None = None
|
||||
|
||||
def format(self):
|
||||
step = f"Executed `{self.action.format_call()}`\n"
|
||||
step += f'- **Reasoning:** "{self.action.reasoning}"\n'
|
||||
step = f"Executed `{self.action.use_tool}`\n"
|
||||
reasoning = (
|
||||
_r.summary()
|
||||
if isinstance(_r := self.action.thoughts, ModelWithSummary)
|
||||
else _r
|
||||
)
|
||||
step += f'- **Reasoning:** "{reasoning}"\n'
|
||||
step += (
|
||||
"- **Status:** "
|
||||
f"`{self.result.status if self.result else 'did_not_finish'}`\n"
|
||||
@@ -114,28 +111,28 @@ class Episode(BaseModel):
|
||||
return step
|
||||
|
||||
def __str__(self) -> str:
|
||||
executed_action = f"Executed `{self.action.format_call()}`"
|
||||
executed_action = f"Executed `{self.action.use_tool}`"
|
||||
action_result = f": {self.result}" if self.result else "."
|
||||
return executed_action + action_result
|
||||
|
||||
|
||||
class EpisodicActionHistory(BaseModel):
|
||||
class EpisodicActionHistory(GenericModel, Generic[AP]):
|
||||
"""Utility container for an action history"""
|
||||
|
||||
episodes: list[Episode] = Field(default_factory=list)
|
||||
episodes: list[Episode[AP]] = Field(default_factory=list)
|
||||
cursor: int = 0
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def current_episode(self) -> Episode | None:
|
||||
def current_episode(self) -> Episode[AP] | None:
|
||||
if self.cursor == len(self):
|
||||
return None
|
||||
return self[self.cursor]
|
||||
|
||||
def __getitem__(self, key: int) -> Episode:
|
||||
def __getitem__(self, key: int) -> Episode[AP]:
|
||||
return self.episodes[key]
|
||||
|
||||
def __iter__(self) -> Iterator[Episode]:
|
||||
def __iter__(self) -> Iterator[Episode[AP]]:
|
||||
return iter(self.episodes)
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -144,7 +141,7 @@ class EpisodicActionHistory(BaseModel):
|
||||
def __bool__(self) -> bool:
|
||||
return len(self.episodes) > 0
|
||||
|
||||
def register_action(self, action: Action) -> None:
|
||||
def register_action(self, action: AP) -> None:
|
||||
if not self.current_episode:
|
||||
self.episodes.append(Episode(action=action, result=None))
|
||||
assert self.current_episode
|
||||
@@ -160,15 +157,6 @@ class EpisodicActionHistory(BaseModel):
|
||||
self.current_episode.result = result
|
||||
self.cursor = len(self.episodes)
|
||||
|
||||
def matches_last_command(
|
||||
self, command_name: CommandName, arguments: CommandArgs
|
||||
) -> bool:
|
||||
"""Check if the last command matches the given name and arguments."""
|
||||
if len(self.episodes) > 0:
|
||||
last_command = self.episodes[-1].action
|
||||
return last_command.name == command_name and last_command.args == arguments
|
||||
return False
|
||||
|
||||
def rewind(self, number_of_episodes: int = 0) -> None:
|
||||
"""Resets the history to an earlier state.
|
||||
|
||||
|
||||
@@ -1,251 +0,0 @@
|
||||
"""Handles loading of plugins."""
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar
|
||||
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
PromptGenerator = TypeVar("PromptGenerator")
|
||||
|
||||
|
||||
class Message(TypedDict):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class BaseOpenAIPlugin(AutoGPTPluginTemplate):
|
||||
"""
|
||||
This is a BaseOpenAIPlugin class for generating AutoGPT plugins.
|
||||
"""
|
||||
|
||||
def __init__(self, manifests_specs_clients: dict):
|
||||
# super().__init__()
|
||||
self._name = manifests_specs_clients["manifest"]["name_for_model"]
|
||||
self._version = manifests_specs_clients["manifest"]["schema_version"]
|
||||
self._description = manifests_specs_clients["manifest"]["description_for_model"]
|
||||
self._client = manifests_specs_clients["client"]
|
||||
self._manifest = manifests_specs_clients["manifest"]
|
||||
self._openapi_spec = manifests_specs_clients["openapi_spec"]
|
||||
|
||||
def can_handle_on_response(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the on_response method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the on_response method."""
|
||||
return False
|
||||
|
||||
def on_response(self, response: str, *args, **kwargs) -> str:
|
||||
"""This method is called when a response is received from the model."""
|
||||
return response
|
||||
|
||||
def can_handle_post_prompt(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the post_prompt method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the post_prompt method."""
|
||||
return False
|
||||
|
||||
def post_prompt(self, prompt: PromptGenerator) -> PromptGenerator:
|
||||
"""This method is called just after the generate_prompt is called,
|
||||
but actually before the prompt is generated.
|
||||
Args:
|
||||
prompt (PromptGenerator): The prompt generator.
|
||||
Returns:
|
||||
PromptGenerator: The prompt generator.
|
||||
"""
|
||||
return prompt
|
||||
|
||||
def can_handle_on_planning(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the on_planning method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the on_planning method."""
|
||||
return False
|
||||
|
||||
def on_planning(
|
||||
self, prompt: PromptGenerator, messages: List[Message]
|
||||
) -> Optional[str]:
|
||||
"""This method is called before the planning chat completion is done.
|
||||
Args:
|
||||
prompt (PromptGenerator): The prompt generator.
|
||||
messages (List[str]): The list of messages.
|
||||
"""
|
||||
|
||||
def can_handle_post_planning(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the post_planning method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the post_planning method."""
|
||||
return False
|
||||
|
||||
def post_planning(self, response: str) -> str:
|
||||
"""This method is called after the planning chat completion is done.
|
||||
Args:
|
||||
response (str): The response.
|
||||
Returns:
|
||||
str: The resulting response.
|
||||
"""
|
||||
return response
|
||||
|
||||
def can_handle_pre_instruction(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the pre_instruction method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the pre_instruction method."""
|
||||
return False
|
||||
|
||||
def pre_instruction(self, messages: List[Message]) -> List[Message]:
|
||||
"""This method is called before the instruction chat is done.
|
||||
Args:
|
||||
messages (List[Message]): The list of context messages.
|
||||
Returns:
|
||||
List[Message]: The resulting list of messages.
|
||||
"""
|
||||
return messages
|
||||
|
||||
def can_handle_on_instruction(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the on_instruction method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the on_instruction method."""
|
||||
return False
|
||||
|
||||
def on_instruction(self, messages: List[Message]) -> Optional[str]:
|
||||
"""This method is called when the instruction chat is done.
|
||||
Args:
|
||||
messages (List[Message]): The list of context messages.
|
||||
Returns:
|
||||
Optional[str]: The resulting message.
|
||||
"""
|
||||
|
||||
def can_handle_post_instruction(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the post_instruction method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the post_instruction method."""
|
||||
return False
|
||||
|
||||
def post_instruction(self, response: str) -> str:
|
||||
"""This method is called after the instruction chat is done.
|
||||
Args:
|
||||
response (str): The response.
|
||||
Returns:
|
||||
str: The resulting response.
|
||||
"""
|
||||
return response
|
||||
|
||||
def can_handle_pre_command(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the pre_command method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the pre_command method."""
|
||||
return False
|
||||
|
||||
def pre_command(
|
||||
self, command_name: str, arguments: Dict[str, Any]
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
"""This method is called before the command is executed.
|
||||
Args:
|
||||
command_name (str): The command name.
|
||||
arguments (Dict[str, Any]): The arguments.
|
||||
Returns:
|
||||
Tuple[str, Dict[str, Any]]: The command name and the arguments.
|
||||
"""
|
||||
return command_name, arguments
|
||||
|
||||
def can_handle_post_command(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the post_command method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the post_command method."""
|
||||
return False
|
||||
|
||||
def post_command(self, command_name: str, response: str) -> str:
|
||||
"""This method is called after the command is executed.
|
||||
Args:
|
||||
command_name (str): The command name.
|
||||
response (str): The response.
|
||||
Returns:
|
||||
str: The resulting response.
|
||||
"""
|
||||
return response
|
||||
|
||||
def can_handle_chat_completion(
|
||||
self, messages: Dict[Any, Any], model: str, temperature: float, max_tokens: int
|
||||
) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the chat_completion method.
|
||||
Args:
|
||||
messages (List[Message]): The messages.
|
||||
model (str): The model name.
|
||||
temperature (float): The temperature.
|
||||
max_tokens (int): The max tokens.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the chat_completion method."""
|
||||
return False
|
||||
|
||||
def handle_chat_completion(
|
||||
self, messages: List[Message], model: str, temperature: float, max_tokens: int
|
||||
) -> str:
|
||||
"""This method is called when the chat completion is done.
|
||||
Args:
|
||||
messages (List[Message]): The messages.
|
||||
model (str): The model name.
|
||||
temperature (float): The temperature.
|
||||
max_tokens (int): The max tokens.
|
||||
Returns:
|
||||
str: The resulting response.
|
||||
"""
|
||||
|
||||
def can_handle_text_embedding(self, text: str) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the text_embedding method.
|
||||
|
||||
Args:
|
||||
text (str): The text to be convert to embedding.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the text_embedding method."""
|
||||
return False
|
||||
|
||||
def handle_text_embedding(self, text: str) -> list[float]:
|
||||
"""This method is called to create a text embedding.
|
||||
|
||||
Args:
|
||||
text (str): The text to be convert to embedding.
|
||||
Returns:
|
||||
list[float]: The created embedding vector.
|
||||
"""
|
||||
|
||||
def can_handle_user_input(self, user_input: str) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the user_input method.
|
||||
|
||||
Args:
|
||||
user_input (str): The user input.
|
||||
|
||||
Returns:
|
||||
bool: True if the plugin can handle the user_input method."""
|
||||
return False
|
||||
|
||||
def user_input(self, user_input: str) -> str:
|
||||
"""This method is called to request user input to the user.
|
||||
|
||||
Args:
|
||||
user_input (str): The question or prompt to ask the user.
|
||||
|
||||
Returns:
|
||||
str: The user input.
|
||||
"""
|
||||
|
||||
def can_handle_report(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the report method.
|
||||
|
||||
Returns:
|
||||
bool: True if the plugin can handle the report method."""
|
||||
return False
|
||||
|
||||
def report(self, message: str) -> None:
|
||||
"""This method is called to report a message to the user.
|
||||
|
||||
Args:
|
||||
message (str): The message to report.
|
||||
"""
|
||||
@@ -1,11 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.base import BaseAgent
|
||||
from autogpt.config import Config
|
||||
from typing import Any, Callable
|
||||
|
||||
from .command_parameter import CommandParameter
|
||||
from .context_item import ContextItem
|
||||
@@ -25,40 +21,42 @@ class Command:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
names: list[str],
|
||||
description: str,
|
||||
method: Callable[..., CommandOutput],
|
||||
parameters: list[CommandParameter],
|
||||
enabled: Literal[True] | Callable[[Config], bool] = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
aliases: list[str] = [],
|
||||
available: bool | Callable[[BaseAgent], bool] = True,
|
||||
):
|
||||
self.name = name
|
||||
# Check if all parameters are provided
|
||||
if not self._parameters_match(method, parameters):
|
||||
raise ValueError(
|
||||
f"Command {names[0]} has different parameters than provided schema"
|
||||
)
|
||||
self.names = names
|
||||
self.description = description
|
||||
self.method = method
|
||||
self.parameters = parameters
|
||||
self.enabled = enabled
|
||||
self.disabled_reason = disabled_reason
|
||||
self.aliases = aliases
|
||||
self.available = available
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
return inspect.iscoroutinefunction(self.method)
|
||||
|
||||
def __call__(self, *args, agent: BaseAgent, **kwargs) -> Any:
|
||||
if callable(self.enabled) and not self.enabled(agent.legacy_config):
|
||||
if self.disabled_reason:
|
||||
raise RuntimeError(
|
||||
f"Command '{self.name}' is disabled: {self.disabled_reason}"
|
||||
)
|
||||
raise RuntimeError(f"Command '{self.name}' is disabled")
|
||||
def _parameters_match(
|
||||
self, func: Callable, parameters: list[CommandParameter]
|
||||
) -> bool:
|
||||
# Get the function's signature
|
||||
signature = inspect.signature(func)
|
||||
# Extract parameter names, ignoring 'self' for methods
|
||||
func_param_names = [
|
||||
param.name
|
||||
for param in signature.parameters.values()
|
||||
if param.name != "self"
|
||||
]
|
||||
names = [param.name for param in parameters]
|
||||
# Check if sorted lists of names/keys are equal
|
||||
return sorted(func_param_names) == sorted(names)
|
||||
|
||||
if not self.available or callable(self.available) and not self.available(agent):
|
||||
raise RuntimeError(f"Command '{self.name}' is not available")
|
||||
|
||||
return self.method(*args, **kwargs, agent=agent)
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
return self.method(*args, **kwargs)
|
||||
|
||||
def __str__(self) -> str:
|
||||
params = [
|
||||
@@ -67,6 +65,18 @@ class Command:
|
||||
for param in self.parameters
|
||||
]
|
||||
return (
|
||||
f"{self.name}: {self.description.rstrip('.')}. "
|
||||
f"{self.names[0]}: {self.description.rstrip('.')}. "
|
||||
f"Params: ({', '.join(params)})"
|
||||
)
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
if instance is None:
|
||||
# Accessed on the class, not an instance
|
||||
return self
|
||||
# Bind the method to the instance
|
||||
return Command(
|
||||
self.names,
|
||||
self.description,
|
||||
self.method.__get__(instance, owner),
|
||||
self.parameters,
|
||||
)
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Any, Iterator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.base import BaseAgent
|
||||
from autogpt.config import Config
|
||||
|
||||
|
||||
from autogpt.command_decorator import AUTO_GPT_COMMAND_IDENTIFIER
|
||||
from autogpt.models.command import Command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommandRegistry:
|
||||
"""
|
||||
The CommandRegistry class is a manager for a collection of Command objects.
|
||||
It allows the registration, modification, and retrieval of Command objects,
|
||||
as well as the scanning and loading of command plugins from a specified
|
||||
directory.
|
||||
"""
|
||||
|
||||
commands: dict[str, Command]
|
||||
commands_aliases: dict[str, Command]
|
||||
|
||||
# Alternative way to structure the registry; currently redundant with self.commands
|
||||
categories: dict[str, CommandCategory]
|
||||
|
||||
@dataclass
|
||||
class CommandCategory:
|
||||
name: str
|
||||
title: str
|
||||
description: str
|
||||
commands: list[Command] = field(default_factory=list[Command])
|
||||
modules: list[ModuleType] = field(default_factory=list[ModuleType])
|
||||
|
||||
def __init__(self):
|
||||
self.commands = {}
|
||||
self.commands_aliases = {}
|
||||
self.categories = {}
|
||||
|
||||
def __contains__(self, command_name: str):
|
||||
return command_name in self.commands or command_name in self.commands_aliases
|
||||
|
||||
def _import_module(self, module_name: str) -> Any:
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
def _reload_module(self, module: Any) -> Any:
|
||||
return importlib.reload(module)
|
||||
|
||||
def register(self, cmd: Command) -> None:
|
||||
if cmd.name in self.commands:
|
||||
logger.warning(
|
||||
f"Command '{cmd.name}' already registered and will be overwritten!"
|
||||
)
|
||||
self.commands[cmd.name] = cmd
|
||||
|
||||
if cmd.name in self.commands_aliases:
|
||||
logger.warning(
|
||||
f"Command '{cmd.name}' will overwrite alias with the same name of "
|
||||
f"'{self.commands_aliases[cmd.name]}'!"
|
||||
)
|
||||
for alias in cmd.aliases:
|
||||
self.commands_aliases[alias] = cmd
|
||||
|
||||
def unregister(self, command: Command) -> None:
|
||||
if command.name in self.commands:
|
||||
del self.commands[command.name]
|
||||
for alias in command.aliases:
|
||||
del self.commands_aliases[alias]
|
||||
else:
|
||||
raise KeyError(f"Command '{command.name}' not found in registry.")
|
||||
|
||||
def reload_commands(self) -> None:
|
||||
"""Reloads all loaded command plugins."""
|
||||
for cmd_name in self.commands:
|
||||
cmd = self.commands[cmd_name]
|
||||
module = self._import_module(cmd.__module__)
|
||||
reloaded_module = self._reload_module(module)
|
||||
if hasattr(reloaded_module, "register"):
|
||||
reloaded_module.register(self)
|
||||
|
||||
def get_command(self, name: str) -> Command | None:
|
||||
if name in self.commands:
|
||||
return self.commands[name]
|
||||
|
||||
if name in self.commands_aliases:
|
||||
return self.commands_aliases[name]
|
||||
|
||||
def call(self, command_name: str, agent: BaseAgent, **kwargs) -> Any:
|
||||
if command := self.get_command(command_name):
|
||||
return command(**kwargs, agent=agent)
|
||||
raise KeyError(f"Command '{command_name}' not found in registry")
|
||||
|
||||
def list_available_commands(self, agent: BaseAgent) -> Iterator[Command]:
|
||||
"""Iterates over all registered commands and yields those that are available.
|
||||
|
||||
Params:
|
||||
agent (BaseAgent): The agent that the commands will be checked against.
|
||||
|
||||
Yields:
|
||||
Command: The next available command.
|
||||
"""
|
||||
|
||||
for cmd in self.commands.values():
|
||||
available = cmd.available
|
||||
if callable(cmd.available):
|
||||
available = cmd.available(agent)
|
||||
if available:
|
||||
yield cmd
|
||||
|
||||
# def command_specs(self) -> str:
|
||||
# """
|
||||
# Returns a technical declaration of all commands in the registry,
|
||||
# for use in a prompt.
|
||||
# """
|
||||
#
|
||||
# Declaring functions or commands should be done in a model-specific way to
|
||||
# achieve optimal results. For this reason, it should NOT be implemented here,
|
||||
# but in an LLM provider module.
|
||||
# MUST take command AVAILABILITY into account.
|
||||
|
||||
@staticmethod
|
||||
def with_command_modules(modules: list[str], config: Config) -> CommandRegistry:
|
||||
new_registry = CommandRegistry()
|
||||
|
||||
logger.debug(
|
||||
"The following command categories are disabled: "
|
||||
f"{config.disabled_command_categories}"
|
||||
)
|
||||
enabled_command_modules = [
|
||||
x for x in modules if x not in config.disabled_command_categories
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
f"The following command categories are enabled: {enabled_command_modules}"
|
||||
)
|
||||
|
||||
for command_module in enabled_command_modules:
|
||||
new_registry.import_command_module(command_module)
|
||||
|
||||
# Unregister commands that are incompatible with the current config
|
||||
for command in [c for c in new_registry.commands.values()]:
|
||||
if callable(command.enabled) and not command.enabled(config):
|
||||
new_registry.unregister(command)
|
||||
logger.debug(
|
||||
f"Unregistering incompatible command '{command.name}':"
|
||||
f" \"{command.disabled_reason or 'Disabled by current config.'}\""
|
||||
)
|
||||
|
||||
return new_registry
|
||||
|
||||
def import_command_module(self, module_name: str) -> None:
|
||||
"""
|
||||
Imports the specified Python module containing command plugins.
|
||||
|
||||
This method imports the associated module and registers any functions or
|
||||
classes that are decorated with the `AUTO_GPT_COMMAND_IDENTIFIER` attribute
|
||||
as `Command` objects. The registered `Command` objects are then added to the
|
||||
`commands` dictionary of the `CommandRegistry` object.
|
||||
|
||||
Args:
|
||||
module_name (str): The name of the module to import for command plugins.
|
||||
"""
|
||||
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
category = self.register_module_category(module)
|
||||
|
||||
for attr_name in dir(module):
|
||||
attr = getattr(module, attr_name)
|
||||
|
||||
command = None
|
||||
|
||||
# Register decorated functions
|
||||
if getattr(attr, AUTO_GPT_COMMAND_IDENTIFIER, False):
|
||||
command = attr.command
|
||||
|
||||
# Register command classes
|
||||
elif (
|
||||
inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
|
||||
):
|
||||
command = attr()
|
||||
|
||||
if command:
|
||||
self.register(command)
|
||||
category.commands.append(command)
|
||||
|
||||
def register_module_category(self, module: ModuleType) -> CommandCategory:
|
||||
if not (category_name := getattr(module, "COMMAND_CATEGORY", None)):
|
||||
raise ValueError(f"Cannot import invalid command module {module.__name__}")
|
||||
|
||||
if category_name not in self.categories:
|
||||
self.categories[category_name] = CommandRegistry.CommandCategory(
|
||||
name=category_name,
|
||||
title=getattr(
|
||||
module, "COMMAND_CATEGORY_TITLE", category_name.capitalize()
|
||||
),
|
||||
description=getattr(module, "__doc__", ""),
|
||||
)
|
||||
|
||||
category = self.categories[category_name]
|
||||
if module not in category.modules:
|
||||
category.modules.append(module)
|
||||
|
||||
return category
|
||||
@@ -1,12 +1,12 @@
|
||||
import logging
|
||||
import os.path
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from autogpt.commands.file_operations_utils import decode_textual_file
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.utils.file_operations_utils import decode_textual_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,67 +24,51 @@ class ContextItem(ABC):
|
||||
"""A string indicating the source location of the context item"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def content(self) -> str:
|
||||
def get_content(self, workspace: FileStorage) -> str:
|
||||
"""The content represented by the context item"""
|
||||
...
|
||||
|
||||
def fmt(self) -> str:
|
||||
def fmt(self, workspace: FileStorage) -> str:
|
||||
return (
|
||||
f"{self.description} (source: {self.source})\n"
|
||||
"```\n"
|
||||
f"{self.content}\n"
|
||||
f"{self.get_content(workspace)}\n"
|
||||
"```"
|
||||
)
|
||||
|
||||
|
||||
class FileContextItem(BaseModel, ContextItem):
|
||||
file_path_in_workspace: Path
|
||||
workspace_path: Path
|
||||
|
||||
@property
|
||||
def file_path(self) -> Path:
|
||||
return self.workspace_path / self.file_path_in_workspace
|
||||
path: Path
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return f"The current content of the file '{self.file_path_in_workspace}'"
|
||||
return f"The current content of the file '{self.path}'"
|
||||
|
||||
@property
|
||||
def source(self) -> str:
|
||||
return str(self.file_path_in_workspace)
|
||||
return str(self.path)
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
# TODO: use workspace.open_file()
|
||||
with open(self.file_path, "rb") as file:
|
||||
return decode_textual_file(file, os.path.splitext(file.name)[1], logger)
|
||||
def get_content(self, workspace: FileStorage) -> str:
|
||||
with workspace.open_file(self.path, "r", True) as file:
|
||||
return decode_textual_file(file, self.path.suffix, logger)
|
||||
|
||||
|
||||
class FolderContextItem(BaseModel, ContextItem):
|
||||
path_in_workspace: Path
|
||||
workspace_path: Path
|
||||
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
return self.workspace_path / self.path_in_workspace
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self.path.exists(), "Selected path does not exist"
|
||||
assert self.path.is_dir(), "Selected path is not a directory"
|
||||
path: Path
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return f"The contents of the folder '{self.path_in_workspace}' in the workspace"
|
||||
return f"The contents of the folder '{self.path}' in the workspace"
|
||||
|
||||
@property
|
||||
def source(self) -> str:
|
||||
return str(self.path_in_workspace)
|
||||
return str(self.path)
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
items = [f"{p.name}{'/' if p.is_dir() else ''}" for p in self.path.iterdir()]
|
||||
def get_content(self, workspace: FileStorage) -> str:
|
||||
files = [str(p) for p in workspace.list_files(self.path)]
|
||||
folders = [f"{str(p)}/" for p in workspace.list_folders(self.path)]
|
||||
items = folders + files
|
||||
items.sort()
|
||||
return "\n".join(items)
|
||||
|
||||
|
||||
10
autogpts/autogpt/autogpt/models/utils.py
Normal file
10
autogpts/autogpt/autogpt/models/utils.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ModelWithSummary(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
def summary(self) -> str:
|
||||
"""Should produce a human readable summary of the model content."""
|
||||
pass
|
||||
@@ -1,330 +0,0 @@
|
||||
"""Handles loading of plugins."""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, List
|
||||
from urllib.parse import urlparse
|
||||
from zipimport import ZipImportError, zipimporter
|
||||
|
||||
import openapi_python_client
|
||||
import requests
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
from openapi_python_client.config import Config as OpenAPIConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
|
||||
from autogpt.models.base_open_ai_plugin import BaseOpenAIPlugin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def inspect_zip_for_modules(zip_path: str) -> list[str]:
|
||||
"""
|
||||
Inspect a zipfile for a modules.
|
||||
|
||||
Args:
|
||||
zip_path (str): Path to the zipfile.
|
||||
debug (bool, optional): Enable debug logging. Defaults to False.
|
||||
|
||||
Returns:
|
||||
list[str]: The list of module names found or empty list if none were found.
|
||||
"""
|
||||
result = []
|
||||
with zipfile.ZipFile(zip_path, "r") as zfile:
|
||||
for name in zfile.namelist():
|
||||
if name.endswith("__init__.py") and not name.startswith("__MACOSX"):
|
||||
logger.debug(f"Found module '{name}' in the zipfile at: {name}")
|
||||
result.append(name)
|
||||
if len(result) == 0:
|
||||
logger.debug(f"Module '__init__.py' not found in the zipfile @ {zip_path}.")
|
||||
return result
|
||||
|
||||
|
||||
def write_dict_to_json_file(data: dict, file_path: str) -> None:
|
||||
"""
|
||||
Write a dictionary to a JSON file.
|
||||
Args:
|
||||
data (dict): Dictionary to write.
|
||||
file_path (str): Path to the file.
|
||||
"""
|
||||
with open(file_path, "w") as file:
|
||||
json.dump(data, file, indent=4)
|
||||
|
||||
|
||||
def fetch_openai_plugins_manifest_and_spec(config: Config) -> dict:
|
||||
"""
|
||||
Fetch the manifest for a list of OpenAI plugins.
|
||||
Args:
|
||||
urls (List): List of URLs to fetch.
|
||||
Returns:
|
||||
dict: per url dictionary of manifest and spec.
|
||||
"""
|
||||
# TODO add directory scan
|
||||
manifests = {}
|
||||
for url in config.plugins_openai:
|
||||
openai_plugin_client_dir = f"{config.plugins_dir}/openai/{urlparse(url).netloc}"
|
||||
create_directory_if_not_exists(openai_plugin_client_dir)
|
||||
if not os.path.exists(f"{openai_plugin_client_dir}/ai-plugin.json"):
|
||||
try:
|
||||
response = requests.get(f"{url}/.well-known/ai-plugin.json")
|
||||
if response.status_code == 200:
|
||||
manifest = response.json()
|
||||
if manifest["schema_version"] != "v1":
|
||||
logger.warning(
|
||||
"Unsupported manifest version: "
|
||||
f"{manifest['schem_version']} for {url}"
|
||||
)
|
||||
continue
|
||||
if manifest["api"]["type"] != "openapi":
|
||||
logger.warning(
|
||||
f"Unsupported API type: {manifest['api']['type']} for {url}"
|
||||
)
|
||||
continue
|
||||
write_dict_to_json_file(
|
||||
manifest, f"{openai_plugin_client_dir}/ai-plugin.json"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to fetch manifest for {url}: {response.status_code}"
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"Error while requesting manifest from {url}: {e}")
|
||||
else:
|
||||
logger.info(f"Manifest for {url} already exists")
|
||||
manifest = json.load(open(f"{openai_plugin_client_dir}/ai-plugin.json"))
|
||||
if not os.path.exists(f"{openai_plugin_client_dir}/openapi.json"):
|
||||
openapi_spec = openapi_python_client._get_document(
|
||||
url=manifest["api"]["url"], path=None, timeout=5
|
||||
)
|
||||
write_dict_to_json_file(
|
||||
openapi_spec, f"{openai_plugin_client_dir}/openapi.json"
|
||||
)
|
||||
else:
|
||||
logger.info(f"OpenAPI spec for {url} already exists")
|
||||
openapi_spec = json.load(open(f"{openai_plugin_client_dir}/openapi.json"))
|
||||
manifests[url] = {"manifest": manifest, "openapi_spec": openapi_spec}
|
||||
return manifests
|
||||
|
||||
|
||||
def create_directory_if_not_exists(directory_path: str) -> bool:
|
||||
"""
|
||||
Create a directory if it does not exist.
|
||||
Args:
|
||||
directory_path (str): Path to the directory.
|
||||
Returns:
|
||||
bool: True if the directory was created, else False.
|
||||
"""
|
||||
if not os.path.exists(directory_path):
|
||||
try:
|
||||
os.makedirs(directory_path)
|
||||
logger.debug(f"Created directory: {directory_path}")
|
||||
return True
|
||||
except OSError as e:
|
||||
logger.warning(f"Error creating directory {directory_path}: {e}")
|
||||
return False
|
||||
else:
|
||||
logger.info(f"Directory {directory_path} already exists")
|
||||
return True
|
||||
|
||||
|
||||
def initialize_openai_plugins(manifests_specs: dict, config: Config) -> dict:
|
||||
"""
|
||||
Initialize OpenAI plugins.
|
||||
Args:
|
||||
manifests_specs (dict): per url dictionary of manifest and spec.
|
||||
config (Config): Config instance including plugins config
|
||||
debug (bool, optional): Enable debug logging. Defaults to False.
|
||||
Returns:
|
||||
dict: per url dictionary of manifest, spec and client.
|
||||
"""
|
||||
openai_plugins_dir = f"{config.plugins_dir}/openai"
|
||||
if create_directory_if_not_exists(openai_plugins_dir):
|
||||
for url, manifest_spec in manifests_specs.items():
|
||||
openai_plugin_client_dir = f"{openai_plugins_dir}/{urlparse(url).hostname}"
|
||||
_meta_option = (openapi_python_client.MetaType.SETUP,)
|
||||
_config = OpenAPIConfig(
|
||||
**{
|
||||
"project_name_override": "client",
|
||||
"package_name_override": "client",
|
||||
}
|
||||
)
|
||||
prev_cwd = Path.cwd()
|
||||
os.chdir(openai_plugin_client_dir)
|
||||
|
||||
if not os.path.exists("client"):
|
||||
client_results = openapi_python_client.create_new_client(
|
||||
url=manifest_spec["manifest"]["api"]["url"],
|
||||
path=None,
|
||||
meta=_meta_option,
|
||||
config=_config,
|
||||
)
|
||||
if client_results:
|
||||
logger.warning(
|
||||
f"Error creating OpenAPI client: {client_results[0].header} \n"
|
||||
f" details: {client_results[0].detail}"
|
||||
)
|
||||
continue
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"client", "client/client/client.py"
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
finally:
|
||||
os.chdir(prev_cwd)
|
||||
|
||||
client = module.Client(base_url=url)
|
||||
manifest_spec["client"] = client
|
||||
return manifests_specs
|
||||
|
||||
|
||||
def instantiate_openai_plugin_clients(manifests_specs_clients: dict) -> dict:
|
||||
"""
|
||||
Instantiates BaseOpenAIPlugin instances for each OpenAI plugin.
|
||||
Args:
|
||||
manifests_specs_clients (dict): per url dictionary of manifest, spec and client.
|
||||
config (Config): Config instance including plugins config
|
||||
debug (bool, optional): Enable debug logging. Defaults to False.
|
||||
Returns:
|
||||
plugins (dict): per url dictionary of BaseOpenAIPlugin instances.
|
||||
|
||||
"""
|
||||
plugins = {}
|
||||
for url, manifest_spec_client in manifests_specs_clients.items():
|
||||
plugins[url] = BaseOpenAIPlugin(manifest_spec_client)
|
||||
return plugins
|
||||
|
||||
|
||||
def scan_plugins(config: Config) -> List[AutoGPTPluginTemplate]:
|
||||
"""Scan the plugins directory for plugins and loads them.
|
||||
|
||||
Args:
|
||||
config (Config): Config instance including plugins config
|
||||
debug (bool, optional): Enable debug logging. Defaults to False.
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, Path]]: List of plugins.
|
||||
"""
|
||||
loaded_plugins = []
|
||||
# Generic plugins
|
||||
plugins_path = Path(config.plugins_dir)
|
||||
|
||||
plugins_config = config.plugins_config
|
||||
# Directory-based plugins
|
||||
for plugin_path in [f for f in Path(config.plugins_dir).iterdir() if f.is_dir()]:
|
||||
# Avoid going into __pycache__ or other hidden directories
|
||||
if plugin_path.name.startswith("__"):
|
||||
continue
|
||||
|
||||
plugin_module_name = plugin_path.name
|
||||
qualified_module_name = ".".join(plugin_path.parts)
|
||||
|
||||
try:
|
||||
plugin = importlib.import_module(qualified_module_name)
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
f"Failed to load {qualified_module_name} from {plugin_path}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not plugins_config.is_enabled(plugin_module_name):
|
||||
logger.warning(
|
||||
f"Plugin folder {plugin_module_name} found but not configured. "
|
||||
"If this is a legitimate plugin, please add it to plugins_config.yaml "
|
||||
f"(key: {plugin_module_name})."
|
||||
)
|
||||
continue
|
||||
|
||||
for _, class_obj in inspect.getmembers(plugin):
|
||||
if (
|
||||
hasattr(class_obj, "_abc_impl")
|
||||
and AutoGPTPluginTemplate in class_obj.__bases__
|
||||
):
|
||||
loaded_plugins.append(class_obj())
|
||||
|
||||
# Zip-based plugins
|
||||
for plugin in plugins_path.glob("*.zip"):
|
||||
if moduleList := inspect_zip_for_modules(str(plugin)):
|
||||
for module in moduleList:
|
||||
plugin = Path(plugin)
|
||||
module = Path(module)
|
||||
logger.debug(f"Zipped Plugin: {plugin}, Module: {module}")
|
||||
zipped_package = zipimporter(str(plugin))
|
||||
try:
|
||||
zipped_module = zipped_package.load_module(str(module.parent))
|
||||
except ZipImportError as e:
|
||||
logger.error(f"Failed to load {module.parent} from {plugin}: {e}")
|
||||
continue
|
||||
|
||||
for key in dir(zipped_module):
|
||||
if key.startswith("__"):
|
||||
continue
|
||||
|
||||
a_module = getattr(zipped_module, key)
|
||||
if not inspect.isclass(a_module):
|
||||
continue
|
||||
|
||||
if (
|
||||
issubclass(a_module, AutoGPTPluginTemplate)
|
||||
and a_module.__name__ != "AutoGPTPluginTemplate"
|
||||
):
|
||||
plugin_name = a_module.__name__
|
||||
plugin_configured = plugins_config.get(plugin_name) is not None
|
||||
plugin_enabled = plugins_config.is_enabled(plugin_name)
|
||||
|
||||
if plugin_configured and plugin_enabled:
|
||||
logger.debug(
|
||||
f"Loading plugin {plugin_name}. "
|
||||
"Enabled in plugins_config.yaml."
|
||||
)
|
||||
loaded_plugins.append(a_module())
|
||||
elif plugin_configured and not plugin_enabled:
|
||||
logger.debug(
|
||||
f"Not loading plugin {plugin_name}. "
|
||||
"Disabled in plugins_config.yaml."
|
||||
)
|
||||
elif not plugin_configured:
|
||||
logger.warning(
|
||||
f"Not loading plugin {plugin_name}. "
|
||||
f"No entry for '{plugin_name}' in plugins_config.yaml. "
|
||||
"Note: Zipped plugins should use the class name "
|
||||
f"({plugin_name}) as the key."
|
||||
)
|
||||
else:
|
||||
if (
|
||||
module_name := getattr(a_module, "__name__", str(a_module))
|
||||
) != "AutoGPTPluginTemplate":
|
||||
logger.debug(
|
||||
f"Skipping '{module_name}' because it doesn't subclass "
|
||||
"AutoGPTPluginTemplate."
|
||||
)
|
||||
|
||||
# OpenAI plugins
|
||||
if config.plugins_openai:
|
||||
manifests_specs = fetch_openai_plugins_manifest_and_spec(config)
|
||||
if manifests_specs.keys():
|
||||
manifests_specs_clients = initialize_openai_plugins(manifests_specs, config)
|
||||
for url, openai_plugin_meta in manifests_specs_clients.items():
|
||||
if not plugins_config.is_enabled(url):
|
||||
plugin_name = openai_plugin_meta["manifest"]["name_for_model"]
|
||||
logger.warning(
|
||||
f"OpenAI Plugin {plugin_name} found but not configured"
|
||||
)
|
||||
continue
|
||||
|
||||
plugin = BaseOpenAIPlugin(openai_plugin_meta)
|
||||
loaded_plugins.append(plugin)
|
||||
|
||||
if loaded_plugins:
|
||||
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
|
||||
for plugin in loaded_plugins:
|
||||
logger.info(f"{plugin._name}: {plugin._version} - {plugin._description}")
|
||||
return loaded_plugins
|
||||
@@ -1,11 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PluginConfig(BaseModel):
|
||||
"""Class for holding configuration of a single plugin"""
|
||||
|
||||
name: str
|
||||
enabled: bool = False
|
||||
config: dict[str, Any] = None
|
||||
@@ -1,118 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogpt.plugins.plugin_config import PluginConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginsConfig(BaseModel):
|
||||
"""Class for holding configuration of all plugins"""
|
||||
|
||||
plugins: dict[str, PluginConfig]
|
||||
|
||||
def __repr__(self):
|
||||
return f"PluginsConfig({self.plugins})"
|
||||
|
||||
def get(self, name: str) -> Union[PluginConfig, None]:
|
||||
return self.plugins.get(name)
|
||||
|
||||
def is_enabled(self, name) -> bool:
|
||||
plugin_config = self.plugins.get(name)
|
||||
return plugin_config is not None and plugin_config.enabled
|
||||
|
||||
@classmethod
|
||||
def load_config(
|
||||
cls,
|
||||
plugins_config_file: Path,
|
||||
plugins_denylist: list[str],
|
||||
plugins_allowlist: list[str],
|
||||
) -> "PluginsConfig":
|
||||
empty_config = cls(plugins={})
|
||||
|
||||
try:
|
||||
config_data = cls.deserialize_config_file(
|
||||
plugins_config_file,
|
||||
plugins_denylist,
|
||||
plugins_allowlist,
|
||||
)
|
||||
if type(config_data) is not dict:
|
||||
logger.error(
|
||||
f"Expected plugins config to be a dict, got {type(config_data)}."
|
||||
" Continuing without plugins."
|
||||
)
|
||||
return empty_config
|
||||
return cls(plugins=config_data)
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(
|
||||
f"Plugin config is invalid. Continuing without plugins. Error: {e}"
|
||||
)
|
||||
return empty_config
|
||||
|
||||
@classmethod
|
||||
def deserialize_config_file(
|
||||
cls,
|
||||
plugins_config_file: Path,
|
||||
plugins_denylist: list[str],
|
||||
plugins_allowlist: list[str],
|
||||
) -> dict[str, PluginConfig]:
|
||||
if not plugins_config_file.is_file():
|
||||
logger.warning("plugins_config.yaml does not exist, creating base config.")
|
||||
cls.create_empty_plugins_config(
|
||||
plugins_config_file,
|
||||
plugins_denylist,
|
||||
plugins_allowlist,
|
||||
)
|
||||
|
||||
with open(plugins_config_file, "r") as f:
|
||||
plugins_config = yaml.load(f, Loader=yaml.SafeLoader)
|
||||
|
||||
plugins = {}
|
||||
for name, plugin in plugins_config.items():
|
||||
if type(plugin) is dict:
|
||||
plugins[name] = PluginConfig(
|
||||
name=name,
|
||||
enabled=plugin.get("enabled", False),
|
||||
config=plugin.get("config", {}),
|
||||
)
|
||||
elif isinstance(plugin, PluginConfig):
|
||||
plugins[name] = plugin
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin config data type: {type(plugin)}")
|
||||
return plugins
|
||||
|
||||
@staticmethod
|
||||
def create_empty_plugins_config(
|
||||
plugins_config_file: Path,
|
||||
plugins_denylist: list[str],
|
||||
plugins_allowlist: list[str],
|
||||
):
|
||||
"""
|
||||
Create an empty plugins_config.yaml file.
|
||||
Fill it with values from old env variables.
|
||||
"""
|
||||
base_config = {}
|
||||
|
||||
logger.debug(f"Legacy plugin denylist: {plugins_denylist}")
|
||||
logger.debug(f"Legacy plugin allowlist: {plugins_allowlist}")
|
||||
|
||||
# Backwards-compatibility shim
|
||||
for plugin_name in plugins_denylist:
|
||||
base_config[plugin_name] = {"enabled": False, "config": {}}
|
||||
|
||||
for plugin_name in plugins_allowlist:
|
||||
base_config[plugin_name] = {"enabled": True, "config": {}}
|
||||
|
||||
logger.debug(f"Constructed base plugins config: {base_config}")
|
||||
|
||||
logger.debug(f"Creating plugin config file {plugins_config_file}")
|
||||
with open(plugins_config_file, "w+") as f:
|
||||
f.write(yaml.dump(base_config))
|
||||
return base_config
|
||||
@@ -1,6 +1,7 @@
|
||||
import functools
|
||||
import re
|
||||
from typing import Any, Callable, ParamSpec, TypeVar
|
||||
from inspect import signature
|
||||
from typing import Callable, ParamSpec, TypeVar
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
P = ParamSpec("P")
|
||||
@@ -14,32 +15,29 @@ def validate_url(func: Callable[P, T]) -> Callable[P, T]:
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(url: str, *args, **kwargs) -> Any:
|
||||
"""Check if the URL is valid and not a local file accessor.
|
||||
def wrapper(*args, **kwargs):
|
||||
sig = signature(func)
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
Args:
|
||||
url (str): The URL to check
|
||||
url = bound_args.arguments.get("url")
|
||||
if url is None:
|
||||
raise ValueError("URL is required for this function")
|
||||
|
||||
Returns:
|
||||
the result of the wrapped function
|
||||
|
||||
Raises:
|
||||
ValueError if the url fails any of the validation tests
|
||||
"""
|
||||
|
||||
# Most basic check if the URL is valid:
|
||||
if not re.match(r"^https?://", url):
|
||||
raise ValueError("Invalid URL format")
|
||||
raise ValueError(
|
||||
"Invalid URL format: URL must start with http:// or https://"
|
||||
)
|
||||
if not is_valid_url(url):
|
||||
raise ValueError("Missing Scheme or Network location")
|
||||
# Restrict access to local files
|
||||
if check_local_file_access(url):
|
||||
raise ValueError("Access to local files is restricted")
|
||||
# Check URL length
|
||||
if len(url) > 2000:
|
||||
raise ValueError("URL is too long")
|
||||
|
||||
return func(sanitize_url(url), *args, **kwargs)
|
||||
bound_args.arguments["url"] = sanitize_url(url)
|
||||
|
||||
return func(*bound_args.args, **bound_args.kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -36,10 +36,6 @@ class UnknownCommandError(AgentException):
|
||||
hint = "Do not try to use this command again."
|
||||
|
||||
|
||||
class DuplicateOperationError(AgentException):
|
||||
"""The proposed operation has already been executed"""
|
||||
|
||||
|
||||
class CommandExecutionError(AgentException):
|
||||
"""An error occurred when trying to execute the command"""
|
||||
|
||||
31
autogpts/autogpt/autogpt/utils/retry_decorator.py
Normal file
31
autogpts/autogpt/autogpt/utils/retry_decorator.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import inspect
|
||||
from typing import Optional
|
||||
|
||||
import sentry_sdk
|
||||
|
||||
|
||||
def retry(retry_count: int = 3, pass_exception: str = "exception"):
|
||||
"""Decorator to retry a function multiple times on failure.
|
||||
Can pass the exception to the function as a keyword argument."""
|
||||
|
||||
def decorator(func):
|
||||
params = inspect.signature(func).parameters
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
exception: Optional[Exception] = None
|
||||
attempts = 0
|
||||
while attempts < retry_count:
|
||||
try:
|
||||
if pass_exception in params:
|
||||
kwargs[pass_exception] = exception
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
attempts += 1
|
||||
exception = e
|
||||
sentry_sdk.capture_exception(e)
|
||||
if attempts >= retry_count:
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -3,6 +3,9 @@ from pathlib import Path
|
||||
import yaml
|
||||
from colorama import Fore
|
||||
|
||||
DEFAULT_FINISH_COMMAND = "finish"
|
||||
DEFAULT_ASK_COMMAND = "ask_user"
|
||||
|
||||
|
||||
def validate_yaml_file(file: str | Path):
|
||||
try:
|
||||
3757
autogpts/autogpt/poetry.lock
generated
3757
autogpts/autogpt/poetry.lock
generated
File diff suppressed because one or more lines are too long
@@ -4,8 +4,6 @@ constraints: [
|
||||
'You are unable to interact with physical objects. If this is absolutely necessary to fulfill a task or objective or to complete a step, you must ask the user to do it for you. If the user refuses this, and there is no other way to achieve your goals, you must terminate to avoid wasting time and energy.'
|
||||
]
|
||||
resources: [
|
||||
'Internet access for searches and information gathering.',
|
||||
'The ability to read and write files.',
|
||||
'You are a Large Language Model, trained on millions of pages of text, including a lot of factual knowledge. Make use of this factual knowledge to avoid unnecessary gathering of information.'
|
||||
]
|
||||
best_practices: [
|
||||
|
||||
@@ -22,9 +22,9 @@ serve = "autogpt.app.cli:serve"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
auto-gpt-plugin-template = {git = "https://github.com/Significant-Gravitas/Auto-GPT-Plugin-Template", rev = "0.1.0"}
|
||||
anthropic = "^0.25.1"
|
||||
# autogpt-forge = { path = "../forge" }
|
||||
autogpt-forge = {git = "https://github.com/Significant-Gravitas/AutoGPT.git", rev = "ab05b7ae70754c063909", subdirectory = "autogpts/forge"}
|
||||
autogpt-forge = {git = "https://github.com/Significant-Gravitas/AutoGPT.git", subdirectory = "autogpts/forge"}
|
||||
beautifulsoup4 = "^4.12.2"
|
||||
boto3 = "^1.33.6"
|
||||
charset-normalizer = "^3.1.0"
|
||||
@@ -64,6 +64,8 @@ spacy = "^3.0.0"
|
||||
tenacity = "^8.2.2"
|
||||
tiktoken = "^0.5.0"
|
||||
webdriver-manager = "*"
|
||||
click-default-group = "^1.2.4"
|
||||
|
||||
|
||||
# OpenAI and Generic plugins import
|
||||
openapi-python-client = "^0.14.0"
|
||||
|
||||
141
autogpts/autogpt/scripts/git_log_to_release_notes.py
Executable file
141
autogpts/autogpt/scripts/git_log_to_release_notes.py
Executable file
@@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from git import Repo, TagReference
|
||||
|
||||
from autogpt.core.resource.model_providers import ChatMessage, MultiProvider
|
||||
from autogpt.core.resource.model_providers.anthropic import AnthropicModelName
|
||||
from autogpt.core.runner.client_lib.utils import coroutine
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--repo-path",
|
||||
type=click.Path(file_okay=False, exists=True),
|
||||
help="Path to the git repository",
|
||||
)
|
||||
@coroutine
|
||||
async def generate_release_notes(repo_path: Optional[Path] = None):
|
||||
logger = logging.getLogger(generate_release_notes.name)
|
||||
|
||||
repo = Repo(repo_path, search_parent_directories=True)
|
||||
tags = list(repo.tags)
|
||||
if not tags:
|
||||
click.echo("No tags found in the repository.")
|
||||
return
|
||||
|
||||
click.echo("Available tags:")
|
||||
for index, tag in enumerate(tags):
|
||||
click.echo(f"{index + 1}: {tag.name}")
|
||||
|
||||
last_release_index = (
|
||||
click.prompt("Enter the number for the last release tag", type=int) - 1
|
||||
)
|
||||
if last_release_index >= len(tags) or last_release_index < 0:
|
||||
click.echo("Invalid tag number entered.")
|
||||
return
|
||||
last_release_tag: TagReference = tags[last_release_index]
|
||||
|
||||
new_release_ref = click.prompt(
|
||||
"Enter the name of the release branch or git ref",
|
||||
default=repo.active_branch.name,
|
||||
)
|
||||
try:
|
||||
new_release_ref = repo.heads[new_release_ref].name
|
||||
except IndexError:
|
||||
try:
|
||||
new_release_ref = repo.tags[new_release_ref].name
|
||||
except IndexError:
|
||||
new_release_ref = repo.commit(new_release_ref).hexsha
|
||||
logger.debug(f"Selected release ref: {new_release_ref}")
|
||||
|
||||
git_log = repo.git.log(
|
||||
f"{last_release_tag.name}...{new_release_ref}",
|
||||
"autogpts/autogpt/",
|
||||
no_merges=True,
|
||||
follow=True,
|
||||
)
|
||||
logger.debug(f"-------------- GIT LOG --------------\n\n{git_log}\n")
|
||||
|
||||
model_provider = MultiProvider()
|
||||
chat_messages = [
|
||||
ChatMessage.system(SYSTEM_PROMPT),
|
||||
ChatMessage.user(content=git_log),
|
||||
]
|
||||
click.echo("Writing release notes ...")
|
||||
completion = await model_provider.create_chat_completion(
|
||||
model_prompt=chat_messages,
|
||||
model_name=AnthropicModelName.CLAUDE3_OPUS_v1,
|
||||
# model_name=OpenAIModelName.GPT4_v4,
|
||||
)
|
||||
|
||||
click.echo("-------------- LLM RESPONSE --------------\n")
|
||||
click.echo(completion.response.content)
|
||||
|
||||
|
||||
EXAMPLE_RELEASE_NOTES = """
|
||||
First some important notes w.r.t. using the application:
|
||||
* `run.sh` has been renamed to `autogpt.sh`
|
||||
* The project has been restructured. The AutoGPT Agent is now located in `autogpts/autogpt`.
|
||||
* The application no longer uses a single workspace for all tasks. Instead, every task that you run the agent on creates a new workspace folder. See the [usage guide](https://docs.agpt.co/autogpt/usage/#workspace) for more information.
|
||||
|
||||
## New features ✨
|
||||
|
||||
* **Agent Protocol 🔌**
|
||||
Our agent now works with the [Agent Protocol](/#-agent-protocol), a REST API that allows creating tasks and executing the agent's step-by-step process. This allows integration with other applications, and we also use it to connect to the agent through the UI.
|
||||
* **UI 💻**
|
||||
With the aforementioned Agent Protocol integration comes the benefit of using our own open-source Agent UI. Easily create, use, and chat with multiple agents from one interface.
|
||||
When starting the application through the project's new [CLI](/#-cli), it runs with the new frontend by default, with benchmarking capabilities. Running `autogpt.sh serve` in the subproject folder (`autogpts/autogpt`) will also serve the new frontend, but without benchmarking functionality.
|
||||
Running the application the "old-fashioned" way, with the terminal interface (let's call it TTY mode), is still possible with `autogpt.sh run`.
|
||||
* **Resuming agents 🔄️**
|
||||
In TTY mode, the application will now save the agent's state when quitting, and allows resuming where you left off at a later time!
|
||||
* **GCS and S3 workspace backends 📦**
|
||||
To further support running the application as part of a larger system, Google Cloud Storage and S3 workspace backends were added. Configuration options for this can be found in [`.env.template`](/autogpts/autogpt/.env.template).
|
||||
* **Documentation Rewrite 📖**
|
||||
The [documentation](https://docs.agpt.co) has been restructured and mostly rewritten to clarify and simplify the instructions, and also to accommodate the other subprojects that are now in the repo.
|
||||
* **New Project CLI 🔧**
|
||||
The project has a new CLI to provide easier usage of all of the components that are now in the repo: different agents, frontend and benchmark. More info can be found [here](/#-cli).
|
||||
* **Docker dev build 🐳**
|
||||
In addition to the regular Docker release [images](https://hub.docker.com/r/significantgravitas/auto-gpt/tags) (`latest`, `v0.5.0` in this case), we now also publish a `latest-dev` image that always contains the latest working build from `master`. This allows you to try out the latest bleeding edge version, but be aware that these builds may contain bugs!
|
||||
|
||||
## Architecture changes & improvements 👷🏼
|
||||
* **PromptStrategy**
|
||||
To make it easier to harness the power of LLMs and use them to fulfil tasks within the application, we adopted the `PromptStrategy` class from `autogpt.core` (AKA re-arch) to encapsulate prompt generation and response parsing throughout the application.
|
||||
* **Config modularization**
|
||||
To reduce the complexity of the application's config structure, parts of the monolithic `Config` have been moved into smaller, tightly scoped config objects. Also, the logic for building the configuration from environment variables was decentralized to make it all a lot more maintainable.
|
||||
This is mostly made possible by the `autogpt.core.configuration` module, which was also expanded with a few new features for it. Most notably, the new `from_env` attribute on the `UserConfigurable` field decorator and corresponding logic in `SystemConfiguration.from_env()` and related functions.
|
||||
* **Monorepo**
|
||||
As mentioned, the repo has been restructured to accommodate the AutoGPT Agent, Forge, AGBenchmark and the new Frontend.
|
||||
* AutoGPT Agent has been moved to `autogpts/autogpt`
|
||||
* Forge now lives in `autogpts/forge`, and the project's new CLI makes it easy to create new Forge-based agents.
|
||||
* AGBenchmark -> `benchmark`
|
||||
* Frontend -> `frontend`
|
||||
|
||||
See also the [README](/#readme).
|
||||
""".lstrip() # noqa
|
||||
|
||||
|
||||
SYSTEM_PROMPT = f"""
|
||||
Please generate release notes based on the user's git log and the example release notes.
|
||||
|
||||
Here is an example of what we like our release notes to look and read like:
|
||||
---------------------------------------------------------------------------
|
||||
{EXAMPLE_RELEASE_NOTES}
|
||||
---------------------------------------------------------------------------
|
||||
NOTE: These example release notes are not related to the git log that you should write release notes for!
|
||||
Do not mention the changes in the example when writing your release notes!
|
||||
""".lstrip() # noqa
|
||||
|
||||
if __name__ == "__main__":
|
||||
import dotenv
|
||||
|
||||
from autogpt.logs.config import configure_logging
|
||||
|
||||
configure_logging(debug=True)
|
||||
|
||||
dotenv.load_dotenv()
|
||||
generate_release_notes()
|
||||
@@ -1,66 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import zipfile
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def install_plugin_dependencies():
|
||||
"""
|
||||
Installs dependencies for all plugins in the plugins dir.
|
||||
|
||||
Args:
|
||||
None
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
plugins_dir = Path(os.getenv("PLUGINS_DIR", "plugins"))
|
||||
|
||||
logger.debug("Checking for dependencies in zipped plugins...")
|
||||
|
||||
# Install zip-based plugins
|
||||
for plugin_archive in plugins_dir.glob("*.zip"):
|
||||
logger.debug(f"Checking for requirements in '{plugin_archive}'...")
|
||||
with zipfile.ZipFile(str(plugin_archive), "r") as zfile:
|
||||
if not zfile.namelist():
|
||||
continue
|
||||
|
||||
# Assume the first entry in the list will be (in) the lowest common dir
|
||||
first_entry = zfile.namelist()[0]
|
||||
basedir = first_entry.rsplit("/", 1)[0] if "/" in first_entry else ""
|
||||
logger.debug(f"Looking for requirements.txt in '{basedir}'")
|
||||
|
||||
basereqs = os.path.join(basedir, "requirements.txt")
|
||||
try:
|
||||
extracted = zfile.extract(basereqs, path=plugins_dir)
|
||||
except KeyError as e:
|
||||
logger.debug(e.args[0])
|
||||
continue
|
||||
|
||||
logger.debug(f"Installing dependencies from '{basereqs}'...")
|
||||
subprocess.check_call(
|
||||
[sys.executable, "-m", "pip", "install", "-r", extracted]
|
||||
)
|
||||
os.remove(extracted)
|
||||
os.rmdir(os.path.join(plugins_dir, basedir))
|
||||
|
||||
logger.debug("Checking for dependencies in other plugin folders...")
|
||||
|
||||
# Install directory-based plugins
|
||||
for requirements_file in glob(f"{plugins_dir}/*/requirements.txt"):
|
||||
logger.debug(f"Installing dependencies from '{requirements_file}'...")
|
||||
subprocess.check_call(
|
||||
[sys.executable, "-m", "pip", "install", "-r", requirements_file],
|
||||
stdout=subprocess.DEVNULL,
|
||||
)
|
||||
|
||||
logger.debug("Finished installing plugin dependencies")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
install_plugin_dependencies()
|
||||
@@ -3,23 +3,20 @@ from __future__ import annotations
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
|
||||
from autogpt.app.main import _configure_openai_provider
|
||||
from autogpt.app.main import _configure_llm_provider
|
||||
from autogpt.config import AIProfile, Config, ConfigBuilder
|
||||
from autogpt.core.resource.model_providers import ChatModelProvider, OpenAIProvider
|
||||
from autogpt.core.resource.model_providers import ChatModelProvider
|
||||
from autogpt.file_storage.local import (
|
||||
FileStorage,
|
||||
FileStorageConfiguration,
|
||||
LocalFileStorage,
|
||||
)
|
||||
from autogpt.logs.config import configure_logging
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
|
||||
pytest_plugins = [
|
||||
"tests.integration.agent_factory",
|
||||
@@ -49,23 +46,8 @@ def storage(app_data_dir: Path) -> FileStorage:
|
||||
return storage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_plugins_config_file():
|
||||
"""
|
||||
Create a plugins_config.yaml file in a temp directory
|
||||
so that it doesn't mess with existing ones.
|
||||
"""
|
||||
config_directory = TemporaryDirectory()
|
||||
config_file = Path(config_directory.name) / "plugins_config.yaml"
|
||||
with open(config_file, "w+") as f:
|
||||
f.write(yaml.dump({}))
|
||||
|
||||
yield config_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def config(
|
||||
temp_plugins_config_file: Path,
|
||||
tmp_project_root: Path,
|
||||
app_data_dir: Path,
|
||||
mocker: MockerFixture,
|
||||
@@ -76,19 +58,8 @@ def config(
|
||||
|
||||
config.app_data_dir = app_data_dir
|
||||
|
||||
config.plugins_dir = "tests/unit/data/test_plugins"
|
||||
config.plugins_config_file = temp_plugins_config_file
|
||||
|
||||
config.noninteractive_mode = True
|
||||
|
||||
# avoid circular dependency
|
||||
from autogpt.plugins.plugins_config import PluginsConfig
|
||||
|
||||
config.plugins_config = PluginsConfig.load_config(
|
||||
plugins_config_file=config.plugins_config_file,
|
||||
plugins_denylist=config.plugins_denylist,
|
||||
plugins_allowlist=config.plugins_allowlist,
|
||||
)
|
||||
yield config
|
||||
|
||||
|
||||
@@ -102,8 +73,8 @@ def setup_logger(config: Config):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_provider(config: Config) -> OpenAIProvider:
|
||||
return _configure_openai_provider(config)
|
||||
def llm_provider(config: Config) -> ChatModelProvider:
|
||||
return _configure_llm_provider(config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -116,11 +87,6 @@ def agent(
|
||||
ai_goals=[],
|
||||
)
|
||||
|
||||
command_registry = CommandRegistry()
|
||||
|
||||
agent_prompt_config = Agent.default_settings.prompt_config.copy(deep=True)
|
||||
agent_prompt_config.use_functions_api = config.openai_functions
|
||||
|
||||
agent_settings = AgentSettings(
|
||||
name=Agent.default_settings.name,
|
||||
description=Agent.default_settings.description,
|
||||
@@ -131,16 +97,13 @@ def agent(
|
||||
smart_llm=config.smart_llm,
|
||||
allow_fs_access=not config.restrict_to_workspace,
|
||||
use_functions_api=config.openai_functions,
|
||||
plugins=config.plugins,
|
||||
),
|
||||
prompt_config=agent_prompt_config,
|
||||
history=Agent.default_settings.history.copy(deep=True),
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
settings=agent_settings,
|
||||
llm_provider=llm_provider,
|
||||
command_registry=command_registry,
|
||||
file_storage=storage,
|
||||
legacy_config=config,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import pytest
|
||||
|
||||
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
|
||||
from autogpt.agents.prompt_strategies.one_shot import OneShotAgentPromptStrategy
|
||||
from autogpt.config import AIProfile, Config
|
||||
from autogpt.file_storage import FileStorageBackendName, get_storage
|
||||
from autogpt.memory.vector import get_memory
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -20,8 +21,6 @@ def memory_json_file(config: Config):
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_agent(config: Config, llm_provider, memory_json_file):
|
||||
command_registry = CommandRegistry()
|
||||
|
||||
ai_profile = AIProfile(
|
||||
ai_name="Dummy Agent",
|
||||
ai_role="Dummy Role",
|
||||
@@ -30,7 +29,9 @@ def dummy_agent(config: Config, llm_provider, memory_json_file):
|
||||
],
|
||||
)
|
||||
|
||||
agent_prompt_config = Agent.default_settings.prompt_config.copy(deep=True)
|
||||
agent_prompt_config = OneShotAgentPromptStrategy.default_configuration.copy(
|
||||
deep=True
|
||||
)
|
||||
agent_prompt_config.use_functions_api = config.openai_functions
|
||||
agent_settings = AgentSettings(
|
||||
name=Agent.default_settings.name,
|
||||
@@ -40,16 +41,22 @@ def dummy_agent(config: Config, llm_provider, memory_json_file):
|
||||
fast_llm=config.fast_llm,
|
||||
smart_llm=config.smart_llm,
|
||||
use_functions_api=config.openai_functions,
|
||||
plugins=config.plugins,
|
||||
),
|
||||
prompt_config=agent_prompt_config,
|
||||
history=Agent.default_settings.history.copy(deep=True),
|
||||
)
|
||||
|
||||
local = config.file_storage_backend == FileStorageBackendName.LOCAL
|
||||
restrict_to_root = not local or config.restrict_to_workspace
|
||||
file_storage = get_storage(
|
||||
config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
|
||||
)
|
||||
file_storage.initialize()
|
||||
|
||||
agent = Agent(
|
||||
settings=agent_settings,
|
||||
llm_provider=llm_provider,
|
||||
command_registry=command_registry,
|
||||
file_storage=file_storage,
|
||||
legacy_config=config,
|
||||
)
|
||||
|
||||
|
||||
@@ -5,12 +5,19 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
import autogpt.commands.execute_code as sut # system under testing
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.utils.exceptions import (
|
||||
InvalidArgumentError,
|
||||
OperationNotAllowedError,
|
||||
from autogpt.commands.execute_code import (
|
||||
ALLOWLIST_CONTROL,
|
||||
CodeExecutorComponent,
|
||||
is_docker_available,
|
||||
we_are_running_in_a_docker_container,
|
||||
)
|
||||
from autogpt.utils.exceptions import InvalidArgumentError, OperationNotAllowedError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def code_executor_component(agent: Agent):
|
||||
return agent.code_executor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -20,7 +27,9 @@ def random_code(random_string) -> str:
|
||||
|
||||
@pytest.fixture
|
||||
def python_test_file(agent: Agent, random_code: str):
|
||||
temp_file = tempfile.NamedTemporaryFile(dir=agent.workspace.root, suffix=".py")
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
dir=agent.file_manager.workspace.root, suffix=".py"
|
||||
)
|
||||
temp_file.write(str.encode(random_code))
|
||||
temp_file.flush()
|
||||
|
||||
@@ -30,7 +39,9 @@ def python_test_file(agent: Agent, random_code: str):
|
||||
|
||||
@pytest.fixture
|
||||
def python_test_args_file(agent: Agent):
|
||||
temp_file = tempfile.NamedTemporaryFile(dir=agent.workspace.root, suffix=".py")
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
dir=agent.file_manager.workspace.root, suffix=".py"
|
||||
)
|
||||
temp_file.write(str.encode("import sys\nprint(sys.argv[1], sys.argv[2])"))
|
||||
temp_file.flush()
|
||||
|
||||
@@ -43,85 +54,114 @@ def random_string():
|
||||
return "".join(random.choice(string.ascii_lowercase) for _ in range(10))
|
||||
|
||||
|
||||
def test_execute_python_file(python_test_file: Path, random_string: str, agent: Agent):
|
||||
if not (sut.is_docker_available() or sut.we_are_running_in_a_docker_container()):
|
||||
def test_execute_python_file(
|
||||
code_executor_component: CodeExecutorComponent,
|
||||
python_test_file: Path,
|
||||
random_string: str,
|
||||
agent: Agent,
|
||||
):
|
||||
if not (is_docker_available() or we_are_running_in_a_docker_container()):
|
||||
pytest.skip("Docker is not available")
|
||||
|
||||
result: str = sut.execute_python_file(python_test_file, agent=agent)
|
||||
result: str = code_executor_component.execute_python_file(python_test_file)
|
||||
assert result.replace("\r", "") == f"Hello {random_string}!\n"
|
||||
|
||||
|
||||
def test_execute_python_file_args(
|
||||
python_test_args_file: Path, random_string: str, agent: Agent
|
||||
code_executor_component: CodeExecutorComponent,
|
||||
python_test_args_file: Path,
|
||||
random_string: str,
|
||||
agent: Agent,
|
||||
):
|
||||
if not (sut.is_docker_available() or sut.we_are_running_in_a_docker_container()):
|
||||
if not (is_docker_available() or we_are_running_in_a_docker_container()):
|
||||
pytest.skip("Docker is not available")
|
||||
|
||||
random_args = [random_string] * 2
|
||||
random_args_string = " ".join(random_args)
|
||||
result = sut.execute_python_file(
|
||||
python_test_args_file, args=random_args, agent=agent
|
||||
result = code_executor_component.execute_python_file(
|
||||
python_test_args_file, args=random_args
|
||||
)
|
||||
assert result == f"{random_args_string}\n"
|
||||
|
||||
|
||||
def test_execute_python_code(random_code: str, random_string: str, agent: Agent):
|
||||
if not (sut.is_docker_available() or sut.we_are_running_in_a_docker_container()):
|
||||
def test_execute_python_code(
|
||||
code_executor_component: CodeExecutorComponent,
|
||||
random_code: str,
|
||||
random_string: str,
|
||||
agent: Agent,
|
||||
):
|
||||
if not (is_docker_available() or we_are_running_in_a_docker_container()):
|
||||
pytest.skip("Docker is not available")
|
||||
|
||||
result: str = sut.execute_python_code(random_code, agent=agent)
|
||||
result: str = code_executor_component.execute_python_code(random_code)
|
||||
assert result.replace("\r", "") == f"Hello {random_string}!\n"
|
||||
|
||||
|
||||
def test_execute_python_file_invalid(agent: Agent):
|
||||
def test_execute_python_file_invalid(
|
||||
code_executor_component: CodeExecutorComponent, agent: Agent
|
||||
):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
sut.execute_python_file(Path("not_python.txt"), agent)
|
||||
code_executor_component.execute_python_file(Path("not_python.txt"))
|
||||
|
||||
|
||||
def test_execute_python_file_not_found(agent: Agent):
|
||||
def test_execute_python_file_not_found(
|
||||
code_executor_component: CodeExecutorComponent, agent: Agent
|
||||
):
|
||||
with pytest.raises(
|
||||
FileNotFoundError,
|
||||
match=r"python: can't open file '([a-zA-Z]:)?[/\\\-\w]*notexist.py': "
|
||||
r"\[Errno 2\] No such file or directory",
|
||||
):
|
||||
sut.execute_python_file(Path("notexist.py"), agent)
|
||||
code_executor_component.execute_python_file(Path("notexist.py"))
|
||||
|
||||
|
||||
def test_execute_shell(random_string: str, agent: Agent):
|
||||
result = sut.execute_shell(f"echo 'Hello {random_string}!'", agent)
|
||||
def test_execute_shell(
|
||||
code_executor_component: CodeExecutorComponent, random_string: str, agent: Agent
|
||||
):
|
||||
result = code_executor_component.execute_shell(f"echo 'Hello {random_string}!'")
|
||||
assert f"Hello {random_string}!" in result
|
||||
|
||||
|
||||
def test_execute_shell_local_commands_not_allowed(random_string: str, agent: Agent):
|
||||
result = sut.execute_shell(f"echo 'Hello {random_string}!'", agent)
|
||||
def test_execute_shell_local_commands_not_allowed(
|
||||
code_executor_component: CodeExecutorComponent, random_string: str, agent: Agent
|
||||
):
|
||||
result = code_executor_component.execute_shell(f"echo 'Hello {random_string}!'")
|
||||
assert f"Hello {random_string}!" in result
|
||||
|
||||
|
||||
def test_execute_shell_denylist_should_deny(agent: Agent, random_string: str):
|
||||
def test_execute_shell_denylist_should_deny(
|
||||
code_executor_component: CodeExecutorComponent, agent: Agent, random_string: str
|
||||
):
|
||||
agent.legacy_config.shell_denylist = ["echo"]
|
||||
|
||||
with pytest.raises(OperationNotAllowedError, match="not allowed"):
|
||||
sut.execute_shell(f"echo 'Hello {random_string}!'", agent)
|
||||
code_executor_component.execute_shell(f"echo 'Hello {random_string}!'")
|
||||
|
||||
|
||||
def test_execute_shell_denylist_should_allow(agent: Agent, random_string: str):
|
||||
def test_execute_shell_denylist_should_allow(
|
||||
code_executor_component: CodeExecutorComponent, agent: Agent, random_string: str
|
||||
):
|
||||
agent.legacy_config.shell_denylist = ["cat"]
|
||||
|
||||
result = sut.execute_shell(f"echo 'Hello {random_string}!'", agent)
|
||||
result = code_executor_component.execute_shell(f"echo 'Hello {random_string}!'")
|
||||
assert "Hello" in result and random_string in result
|
||||
|
||||
|
||||
def test_execute_shell_allowlist_should_deny(agent: Agent, random_string: str):
|
||||
agent.legacy_config.shell_command_control = sut.ALLOWLIST_CONTROL
|
||||
def test_execute_shell_allowlist_should_deny(
|
||||
code_executor_component: CodeExecutorComponent, agent: Agent, random_string: str
|
||||
):
|
||||
agent.legacy_config.shell_command_control = ALLOWLIST_CONTROL
|
||||
agent.legacy_config.shell_allowlist = ["cat"]
|
||||
|
||||
with pytest.raises(OperationNotAllowedError, match="not allowed"):
|
||||
sut.execute_shell(f"echo 'Hello {random_string}!'", agent)
|
||||
code_executor_component.execute_shell(f"echo 'Hello {random_string}!'")
|
||||
|
||||
|
||||
def test_execute_shell_allowlist_should_allow(agent: Agent, random_string: str):
|
||||
agent.legacy_config.shell_command_control = sut.ALLOWLIST_CONTROL
|
||||
def test_execute_shell_allowlist_should_allow(
|
||||
code_executor_component: CodeExecutorComponent, agent: Agent, random_string: str
|
||||
):
|
||||
agent.legacy_config.shell_command_control = ALLOWLIST_CONTROL
|
||||
agent.legacy_config.shell_allowlist = ["echo"]
|
||||
|
||||
result = sut.execute_shell(f"echo 'Hello {random_string}!'", agent)
|
||||
result = code_executor_component.execute_shell(f"echo 'Hello {random_string}!'")
|
||||
assert "Hello" in result and random_string in result
|
||||
|
||||
@@ -7,7 +7,12 @@ import pytest
|
||||
from PIL import Image
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.commands.image_gen import generate_image, generate_image_with_sd_webui
|
||||
from autogpt.commands.image_gen import ImageGeneratorComponent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_gen_component(agent: Agent):
|
||||
return agent.image_gen
|
||||
|
||||
|
||||
@pytest.fixture(params=[256, 512, 1024])
|
||||
@@ -18,9 +23,16 @@ def image_size(request):
|
||||
|
||||
@pytest.mark.requires_openai_api_key
|
||||
@pytest.mark.vcr
|
||||
def test_dalle(agent: Agent, storage, image_size, cached_openai_client):
|
||||
def test_dalle(
|
||||
image_gen_component: ImageGeneratorComponent,
|
||||
agent: Agent,
|
||||
storage,
|
||||
image_size,
|
||||
cached_openai_client,
|
||||
):
|
||||
"""Test DALL-E image generation."""
|
||||
generate_and_validate(
|
||||
image_gen_component,
|
||||
agent,
|
||||
storage,
|
||||
image_provider="dalle",
|
||||
@@ -37,9 +49,16 @@ def test_dalle(agent: Agent, storage, image_size, cached_openai_client):
|
||||
"image_model",
|
||||
["CompVis/stable-diffusion-v1-4", "stabilityai/stable-diffusion-2-1"],
|
||||
)
|
||||
def test_huggingface(agent: Agent, storage, image_size, image_model):
|
||||
def test_huggingface(
|
||||
image_gen_component: ImageGeneratorComponent,
|
||||
agent: Agent,
|
||||
storage,
|
||||
image_size,
|
||||
image_model,
|
||||
):
|
||||
"""Test HuggingFace image generation."""
|
||||
generate_and_validate(
|
||||
image_gen_component,
|
||||
agent,
|
||||
storage,
|
||||
image_provider="huggingface",
|
||||
@@ -49,9 +68,12 @@ def test_huggingface(agent: Agent, storage, image_size, image_model):
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="SD WebUI call does not work.")
|
||||
def test_sd_webui(agent: Agent, storage, image_size):
|
||||
def test_sd_webui(
|
||||
image_gen_component: ImageGeneratorComponent, agent: Agent, storage, image_size
|
||||
):
|
||||
"""Test SD WebUI image generation."""
|
||||
generate_and_validate(
|
||||
image_gen_component,
|
||||
agent,
|
||||
storage,
|
||||
image_provider="sd_webui",
|
||||
@@ -60,11 +82,12 @@ def test_sd_webui(agent: Agent, storage, image_size):
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="SD WebUI call does not work.")
|
||||
def test_sd_webui_negative_prompt(agent: Agent, storage, image_size):
|
||||
def test_sd_webui_negative_prompt(
|
||||
image_gen_component: ImageGeneratorComponent, storage, image_size
|
||||
):
|
||||
gen_image = functools.partial(
|
||||
generate_image_with_sd_webui,
|
||||
image_gen_component.generate_image_with_sd_webui,
|
||||
prompt="astronaut riding a horse",
|
||||
agent=agent,
|
||||
size=image_size,
|
||||
extra={"seed": 123},
|
||||
)
|
||||
@@ -90,6 +113,7 @@ def lst(txt):
|
||||
|
||||
|
||||
def generate_and_validate(
|
||||
image_gen_component: ImageGeneratorComponent,
|
||||
agent: Agent,
|
||||
storage,
|
||||
image_size,
|
||||
@@ -103,7 +127,7 @@ def generate_and_validate(
|
||||
agent.legacy_config.huggingface_image_model = hugging_face_image_model
|
||||
prompt = "astronaut riding a horse"
|
||||
|
||||
image_path = lst(generate_image(prompt, agent, image_size, **kwargs))
|
||||
image_path = lst(image_gen_component.generate_image(prompt, image_size, **kwargs))
|
||||
assert image_path.exists()
|
||||
with Image.open(image_path) as img:
|
||||
assert img.size == (image_size, image_size)
|
||||
@@ -125,7 +149,13 @@ def generate_and_validate(
|
||||
)
|
||||
@pytest.mark.parametrize("delay", [10, 0])
|
||||
def test_huggingface_fail_request_with_delay(
|
||||
agent: Agent, storage, image_size, image_model, return_text, delay
|
||||
image_gen_component: ImageGeneratorComponent,
|
||||
agent: Agent,
|
||||
storage,
|
||||
image_size,
|
||||
image_model,
|
||||
return_text,
|
||||
delay,
|
||||
):
|
||||
return_text = return_text.replace("[model]", image_model).replace(
|
||||
"[delay]", str(delay)
|
||||
@@ -150,7 +180,7 @@ def test_huggingface_fail_request_with_delay(
|
||||
|
||||
with patch("time.sleep") as mock_sleep:
|
||||
# Verify request fails.
|
||||
result = generate_image(prompt, agent, image_size)
|
||||
result = image_gen_component.generate_image(prompt, image_size)
|
||||
assert result == "Error creating image."
|
||||
|
||||
# Verify retry was called with delay if delay is in return_text
|
||||
@@ -160,7 +190,9 @@ def test_huggingface_fail_request_with_delay(
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
def test_huggingface_fail_request_no_delay(mocker, agent: Agent):
|
||||
def test_huggingface_fail_request_no_delay(
|
||||
mocker, image_gen_component: ImageGeneratorComponent, agent: Agent
|
||||
):
|
||||
agent.legacy_config.huggingface_api_token = "1"
|
||||
|
||||
# Mock requests.post
|
||||
@@ -177,7 +209,7 @@ def test_huggingface_fail_request_no_delay(mocker, agent: Agent):
|
||||
agent.legacy_config.image_provider = "huggingface"
|
||||
agent.legacy_config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
result = generate_image("astronaut riding a horse", agent, 512)
|
||||
result = image_gen_component.generate_image("astronaut riding a horse", 512)
|
||||
|
||||
assert result == "Error creating image."
|
||||
|
||||
@@ -185,7 +217,9 @@ def test_huggingface_fail_request_no_delay(mocker, agent: Agent):
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
def test_huggingface_fail_request_bad_json(mocker, agent: Agent):
|
||||
def test_huggingface_fail_request_bad_json(
|
||||
mocker, image_gen_component: ImageGeneratorComponent, agent: Agent
|
||||
):
|
||||
agent.legacy_config.huggingface_api_token = "1"
|
||||
|
||||
# Mock requests.post
|
||||
@@ -200,7 +234,7 @@ def test_huggingface_fail_request_bad_json(mocker, agent: Agent):
|
||||
agent.legacy_config.image_provider = "huggingface"
|
||||
agent.legacy_config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
result = generate_image("astronaut riding a horse", agent, 512)
|
||||
result = image_gen_component.generate_image("astronaut riding a horse", 512)
|
||||
|
||||
assert result == "Error creating image."
|
||||
|
||||
@@ -208,7 +242,9 @@ def test_huggingface_fail_request_bad_json(mocker, agent: Agent):
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
def test_huggingface_fail_request_bad_image(mocker, agent: Agent):
|
||||
def test_huggingface_fail_request_bad_image(
|
||||
mocker, image_gen_component: ImageGeneratorComponent, agent: Agent
|
||||
):
|
||||
agent.legacy_config.huggingface_api_token = "1"
|
||||
|
||||
# Mock requests.post
|
||||
@@ -218,12 +254,14 @@ def test_huggingface_fail_request_bad_image(mocker, agent: Agent):
|
||||
agent.legacy_config.image_provider = "huggingface"
|
||||
agent.legacy_config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
result = generate_image("astronaut riding a horse", agent, 512)
|
||||
result = image_gen_component.generate_image("astronaut riding a horse", 512)
|
||||
|
||||
assert result == "Error creating image."
|
||||
|
||||
|
||||
def test_huggingface_fail_missing_api_token(mocker, agent: Agent):
|
||||
def test_huggingface_fail_missing_api_token(
|
||||
mocker, image_gen_component: ImageGeneratorComponent, agent: Agent
|
||||
):
|
||||
agent.legacy_config.image_provider = "huggingface"
|
||||
agent.legacy_config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
@@ -232,4 +270,4 @@ def test_huggingface_fail_missing_api_token(mocker, agent: Agent):
|
||||
|
||||
# Verify request raises an error.
|
||||
with pytest.raises(ValueError):
|
||||
generate_image("astronaut riding a horse", agent, 512)
|
||||
image_gen_component.generate_image("astronaut riding a horse", 512)
|
||||
|
||||
@@ -1,18 +1,25 @@
|
||||
import pytest
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.commands.web_selenium import BrowsingError, read_webpage
|
||||
from autogpt.commands.web_selenium import BrowsingError, WebSeleniumComponent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def web_selenium_component(agent: Agent):
|
||||
return agent.web_selenium
|
||||
|
||||
|
||||
@pytest.mark.vcr
|
||||
@pytest.mark.requires_openai_api_key
|
||||
@pytest.mark.asyncio
|
||||
async def test_browse_website_nonexistent_url(agent: Agent, cached_openai_client: None):
|
||||
async def test_browse_website_nonexistent_url(
|
||||
web_selenium_component: WebSeleniumComponent, cached_openai_client: None
|
||||
):
|
||||
url = "https://auto-gpt-thinks-this-website-does-not-exist.com"
|
||||
question = "How to execute a barrel roll"
|
||||
|
||||
with pytest.raises(BrowsingError, match="NAME_NOT_RESOLVED") as raised:
|
||||
await read_webpage(url=url, question=question, agent=agent)
|
||||
await web_selenium_component.read_webpage(url=url, question=question)
|
||||
|
||||
# Sanity check that the response is not too long
|
||||
assert len(raised.exconly()) < 200
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
COMMAND_CATEGORY = "mock"
|
||||
|
||||
|
||||
@command(
|
||||
"function_based_cmd",
|
||||
"Function-based test command",
|
||||
{
|
||||
"arg1": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="arg 1",
|
||||
required=True,
|
||||
),
|
||||
"arg2": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="arg 2",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
def function_based_cmd(arg1: int, arg2: str) -> str:
|
||||
"""A function-based test command.
|
||||
|
||||
Returns:
|
||||
str: the two arguments separated by a dash.
|
||||
"""
|
||||
return f"{arg1} - {arg2}"
|
||||
Binary file not shown.
@@ -1,274 +0,0 @@
|
||||
"""This is the Test plugin for AutoGPT."""
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypeVar
|
||||
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
PromptGenerator = TypeVar("PromptGenerator")
|
||||
|
||||
|
||||
class AutoGPTGuanaco(AutoGPTPluginTemplate):
|
||||
"""
|
||||
This is plugin for AutoGPT.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._name = "AutoGPT-Guanaco"
|
||||
self._version = "0.1.0"
|
||||
self._description = "This is a Guanaco local model plugin."
|
||||
|
||||
def can_handle_on_response(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the on_response method.
|
||||
|
||||
Returns:
|
||||
bool: True if the plugin can handle the on_response method."""
|
||||
return False
|
||||
|
||||
def on_response(self, response: str, *args, **kwargs) -> str:
|
||||
"""This method is called when a response is received from the model."""
|
||||
if len(response):
|
||||
print("OMG OMG It's Alive!")
|
||||
else:
|
||||
print("Is it alive?")
|
||||
|
||||
def can_handle_post_prompt(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the post_prompt method.
|
||||
|
||||
Returns:
|
||||
bool: True if the plugin can handle the post_prompt method."""
|
||||
return False
|
||||
|
||||
def post_prompt(self, prompt: PromptGenerator) -> PromptGenerator:
|
||||
"""This method is called just after the generate_prompt is called,
|
||||
but actually before the prompt is generated.
|
||||
|
||||
Args:
|
||||
prompt (PromptGenerator): The prompt generator.
|
||||
|
||||
Returns:
|
||||
PromptGenerator: The prompt generator.
|
||||
"""
|
||||
|
||||
def can_handle_on_planning(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the on_planning method.
|
||||
|
||||
Returns:
|
||||
bool: True if the plugin can handle the on_planning method."""
|
||||
return False
|
||||
|
||||
def on_planning(
|
||||
self, prompt: PromptGenerator, messages: List[str]
|
||||
) -> Optional[str]:
|
||||
"""This method is called before the planning chat completeion is done.
|
||||
|
||||
Args:
|
||||
prompt (PromptGenerator): The prompt generator.
|
||||
messages (List[str]): The list of messages.
|
||||
"""
|
||||
|
||||
def can_handle_post_planning(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the post_planning method.
|
||||
|
||||
Returns:
|
||||
bool: True if the plugin can handle the post_planning method."""
|
||||
return False
|
||||
|
||||
def post_planning(self, response: str) -> str:
|
||||
"""This method is called after the planning chat completeion is done.
|
||||
|
||||
Args:
|
||||
response (str): The response.
|
||||
|
||||
Returns:
|
||||
str: The resulting response.
|
||||
"""
|
||||
|
||||
def can_handle_pre_instruction(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the pre_instruction method.
|
||||
|
||||
Returns:
|
||||
bool: True if the plugin can handle the pre_instruction method."""
|
||||
return False
|
||||
|
||||
def pre_instruction(self, messages: List[str]) -> List[str]:
|
||||
"""This method is called before the instruction chat is done.
|
||||
|
||||
Args:
|
||||
messages (List[str]): The list of context messages.
|
||||
|
||||
Returns:
|
||||
List[str]: The resulting list of messages.
|
||||
"""
|
||||
|
||||
def can_handle_on_instruction(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the on_instruction method.
|
||||
|
||||
Returns:
|
||||
bool: True if the plugin can handle the on_instruction method."""
|
||||
return False
|
||||
|
||||
def on_instruction(self, messages: List[str]) -> Optional[str]:
|
||||
"""This method is called when the instruction chat is done.
|
||||
|
||||
Args:
|
||||
messages (List[str]): The list of context messages.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The resulting message.
|
||||
"""
|
||||
|
||||
def can_handle_post_instruction(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the post_instruction method.
|
||||
|
||||
Returns:
|
||||
bool: True if the plugin can handle the post_instruction method."""
|
||||
return False
|
||||
|
||||
def post_instruction(self, response: str) -> str:
|
||||
"""This method is called after the instruction chat is done.
|
||||
|
||||
Args:
|
||||
response (str): The response.
|
||||
|
||||
Returns:
|
||||
str: The resulting response.
|
||||
"""
|
||||
|
||||
def can_handle_pre_command(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the pre_command method.
|
||||
|
||||
Returns:
|
||||
bool: True if the plugin can handle the pre_command method."""
|
||||
return False
|
||||
|
||||
def pre_command(
|
||||
self, command_name: str, arguments: Dict[str, Any]
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
"""This method is called before the command is executed.
|
||||
|
||||
Args:
|
||||
command_name (str): The command name.
|
||||
arguments (Dict[str, Any]): The arguments.
|
||||
|
||||
Returns:
|
||||
Tuple[str, Dict[str, Any]]: The command name and the arguments.
|
||||
"""
|
||||
|
||||
def can_handle_post_command(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the post_command method.
|
||||
|
||||
Returns:
|
||||
bool: True if the plugin can handle the post_command method."""
|
||||
return False
|
||||
|
||||
def post_command(self, command_name: str, response: str) -> str:
|
||||
"""This method is called after the command is executed.
|
||||
|
||||
Args:
|
||||
command_name (str): The command name.
|
||||
response (str): The response.
|
||||
|
||||
Returns:
|
||||
str: The resulting response.
|
||||
"""
|
||||
|
||||
def can_handle_chat_completion(
|
||||
self,
|
||||
messages: list[Dict[Any, Any]],
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the chat_completion method.
|
||||
|
||||
Args:
|
||||
messages (Dict[Any, Any]): The messages.
|
||||
model (str): The model name.
|
||||
temperature (float): The temperature.
|
||||
max_tokens (int): The max tokens.
|
||||
|
||||
Returns:
|
||||
bool: True if the plugin can handle the chat_completion method."""
|
||||
return False
|
||||
|
||||
def handle_chat_completion(
|
||||
self,
|
||||
messages: list[Dict[Any, Any]],
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> str:
|
||||
"""This method is called when the chat completion is done.
|
||||
|
||||
Args:
|
||||
messages (Dict[Any, Any]): The messages.
|
||||
model (str): The model name.
|
||||
temperature (float): The temperature.
|
||||
max_tokens (int): The max tokens.
|
||||
|
||||
Returns:
|
||||
str: The resulting response.
|
||||
"""
|
||||
|
||||
def can_handle_text_embedding(self, text: str) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the text_embedding method.
|
||||
Args:
|
||||
text (str): The text to be convert to embedding.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the text_embedding method."""
|
||||
return False
|
||||
|
||||
def handle_text_embedding(self, text: str) -> list:
|
||||
"""This method is called when the chat completion is done.
|
||||
Args:
|
||||
text (str): The text to be convert to embedding.
|
||||
Returns:
|
||||
list: The text embedding.
|
||||
"""
|
||||
|
||||
def can_handle_user_input(self, user_input: str) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the user_input method.
|
||||
|
||||
Args:
|
||||
user_input (str): The user input.
|
||||
|
||||
Returns:
|
||||
bool: True if the plugin can handle the user_input method."""
|
||||
return False
|
||||
|
||||
def user_input(self, user_input: str) -> str:
|
||||
"""This method is called to request user input to the user.
|
||||
|
||||
Args:
|
||||
user_input (str): The question or prompt to ask the user.
|
||||
|
||||
Returns:
|
||||
str: The user input.
|
||||
"""
|
||||
|
||||
def can_handle_report(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the report method.
|
||||
|
||||
Returns:
|
||||
bool: True if the plugin can handle the report method."""
|
||||
return False
|
||||
|
||||
def report(self, message: str) -> None:
|
||||
"""This method is called to report a message to the user.
|
||||
|
||||
Args:
|
||||
message (str): The message to report.
|
||||
"""
|
||||
@@ -1,81 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from autogpt.models.base_open_ai_plugin import BaseOpenAIPlugin
|
||||
|
||||
|
||||
class DummyPlugin(BaseOpenAIPlugin):
|
||||
"""A dummy plugin for testing purposes."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_plugin():
|
||||
"""A dummy plugin for testing purposes."""
|
||||
manifests_specs_clients = {
|
||||
"manifest": {
|
||||
"name_for_model": "Dummy",
|
||||
"schema_version": "1.0",
|
||||
"description_for_model": "A dummy plugin for testing purposes",
|
||||
},
|
||||
"client": None,
|
||||
"openapi_spec": None,
|
||||
}
|
||||
return DummyPlugin(manifests_specs_clients)
|
||||
|
||||
|
||||
def test_dummy_plugin_inheritance(dummy_plugin):
|
||||
"""Test that the DummyPlugin class inherits from the BaseOpenAIPlugin class."""
|
||||
assert isinstance(dummy_plugin, BaseOpenAIPlugin)
|
||||
|
||||
|
||||
def test_dummy_plugin_name(dummy_plugin):
|
||||
"""Test that the DummyPlugin class has the correct name."""
|
||||
assert dummy_plugin._name == "Dummy"
|
||||
|
||||
|
||||
def test_dummy_plugin_version(dummy_plugin):
|
||||
"""Test that the DummyPlugin class has the correct version."""
|
||||
assert dummy_plugin._version == "1.0"
|
||||
|
||||
|
||||
def test_dummy_plugin_description(dummy_plugin):
|
||||
"""Test that the DummyPlugin class has the correct description."""
|
||||
assert dummy_plugin._description == "A dummy plugin for testing purposes"
|
||||
|
||||
|
||||
def test_dummy_plugin_default_methods(dummy_plugin):
|
||||
"""Test that the DummyPlugin class has the correct default methods."""
|
||||
assert not dummy_plugin.can_handle_on_response()
|
||||
assert not dummy_plugin.can_handle_post_prompt()
|
||||
assert not dummy_plugin.can_handle_on_planning()
|
||||
assert not dummy_plugin.can_handle_post_planning()
|
||||
assert not dummy_plugin.can_handle_pre_instruction()
|
||||
assert not dummy_plugin.can_handle_on_instruction()
|
||||
assert not dummy_plugin.can_handle_post_instruction()
|
||||
assert not dummy_plugin.can_handle_pre_command()
|
||||
assert not dummy_plugin.can_handle_post_command()
|
||||
assert not dummy_plugin.can_handle_chat_completion(None, None, None, None)
|
||||
assert not dummy_plugin.can_handle_text_embedding(None)
|
||||
|
||||
assert dummy_plugin.on_response("hello") == "hello"
|
||||
assert dummy_plugin.post_prompt(None) is None
|
||||
assert dummy_plugin.on_planning(None, None) is None
|
||||
assert dummy_plugin.post_planning("world") == "world"
|
||||
pre_instruction = dummy_plugin.pre_instruction(
|
||||
[{"role": "system", "content": "Beep, bop, boop"}]
|
||||
)
|
||||
assert isinstance(pre_instruction, list)
|
||||
assert len(pre_instruction) == 1
|
||||
assert pre_instruction[0]["role"] == "system"
|
||||
assert pre_instruction[0]["content"] == "Beep, bop, boop"
|
||||
assert dummy_plugin.on_instruction(None) is None
|
||||
assert dummy_plugin.post_instruction("I'm a robot") == "I'm a robot"
|
||||
pre_command = dummy_plugin.pre_command("evolve", {"continuously": True})
|
||||
assert isinstance(pre_command, tuple)
|
||||
assert len(pre_command) == 2
|
||||
assert pre_command[0] == "evolve"
|
||||
assert pre_command[1]["continuously"] is True
|
||||
post_command = dummy_plugin.post_command("evolve", "upgraded successfully!")
|
||||
assert isinstance(post_command, str)
|
||||
assert post_command == "upgraded successfully!"
|
||||
assert dummy_plugin.handle_chat_completion(None, None, None, None) is None
|
||||
assert dummy_plugin.handle_text_embedding(None) is None
|
||||
@@ -1,239 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents import Agent, BaseAgent
|
||||
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.models.command import Command, CommandParameter
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
|
||||
PARAMETERS = [
|
||||
CommandParameter(
|
||||
"arg1",
|
||||
spec=JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="Argument 1",
|
||||
required=True,
|
||||
),
|
||||
),
|
||||
CommandParameter(
|
||||
"arg2",
|
||||
spec=JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="Argument 2",
|
||||
required=False,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def example_command_method(arg1: int, arg2: str, agent: BaseAgent) -> str:
|
||||
"""Example function for testing the Command class."""
|
||||
# This function is static because it is not used by any other test cases.
|
||||
return f"{arg1} - {arg2}"
|
||||
|
||||
|
||||
def test_command_creation():
|
||||
"""Test that a Command object can be created with the correct attributes."""
|
||||
cmd = Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=example_command_method,
|
||||
parameters=PARAMETERS,
|
||||
)
|
||||
|
||||
assert cmd.name == "example"
|
||||
assert cmd.description == "Example command"
|
||||
assert cmd.method == example_command_method
|
||||
assert (
|
||||
str(cmd)
|
||||
== "example: Example command. Params: (arg1: integer, arg2: Optional[string])"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_command():
|
||||
yield Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=example_command_method,
|
||||
parameters=PARAMETERS,
|
||||
)
|
||||
|
||||
|
||||
def test_command_call(example_command: Command, agent: Agent):
|
||||
"""Test that Command(*args) calls and returns the result of method(*args)."""
|
||||
result = example_command(arg1=1, arg2="test", agent=agent)
|
||||
assert result == "1 - test"
|
||||
|
||||
|
||||
def test_command_call_with_invalid_arguments(example_command: Command, agent: Agent):
|
||||
"""Test that calling a Command object with invalid arguments raises a TypeError."""
|
||||
with pytest.raises(TypeError):
|
||||
example_command(arg1="invalid", does_not_exist="test", agent=agent)
|
||||
|
||||
|
||||
def test_register_command(example_command: Command):
|
||||
"""Test that a command can be registered to the registry."""
|
||||
registry = CommandRegistry()
|
||||
|
||||
registry.register(example_command)
|
||||
|
||||
assert registry.get_command(example_command.name) == example_command
|
||||
assert len(registry.commands) == 1
|
||||
|
||||
|
||||
def test_unregister_command(example_command: Command):
|
||||
"""Test that a command can be unregistered from the registry."""
|
||||
registry = CommandRegistry()
|
||||
|
||||
registry.register(example_command)
|
||||
registry.unregister(example_command)
|
||||
|
||||
assert len(registry.commands) == 0
|
||||
assert example_command.name not in registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_command_with_aliases(example_command: Command):
|
||||
example_command.aliases = ["example_alias", "example_alias_2"]
|
||||
return example_command
|
||||
|
||||
|
||||
def test_register_command_aliases(example_command_with_aliases: Command):
|
||||
"""Test that a command can be registered to the registry."""
|
||||
registry = CommandRegistry()
|
||||
command = example_command_with_aliases
|
||||
|
||||
registry.register(command)
|
||||
|
||||
assert command.name in registry
|
||||
assert registry.get_command(command.name) == command
|
||||
for alias in command.aliases:
|
||||
assert registry.get_command(alias) == command
|
||||
assert len(registry.commands) == 1
|
||||
|
||||
|
||||
def test_unregister_command_aliases(example_command_with_aliases: Command):
|
||||
"""Test that a command can be unregistered from the registry."""
|
||||
registry = CommandRegistry()
|
||||
command = example_command_with_aliases
|
||||
|
||||
registry.register(command)
|
||||
registry.unregister(command)
|
||||
|
||||
assert len(registry.commands) == 0
|
||||
assert command.name not in registry
|
||||
for alias in command.aliases:
|
||||
assert alias not in registry
|
||||
|
||||
|
||||
def test_command_in_registry(example_command_with_aliases: Command):
|
||||
"""Test that `command_name in registry` works."""
|
||||
registry = CommandRegistry()
|
||||
command = example_command_with_aliases
|
||||
|
||||
assert command.name not in registry
|
||||
assert "nonexistent_command" not in registry
|
||||
|
||||
registry.register(command)
|
||||
|
||||
assert command.name in registry
|
||||
assert "nonexistent_command" not in registry
|
||||
for alias in command.aliases:
|
||||
assert alias in registry
|
||||
|
||||
|
||||
def test_get_command(example_command: Command):
|
||||
"""Test that a command can be retrieved from the registry."""
|
||||
registry = CommandRegistry()
|
||||
|
||||
registry.register(example_command)
|
||||
retrieved_cmd = registry.get_command(example_command.name)
|
||||
|
||||
assert retrieved_cmd == example_command
|
||||
|
||||
|
||||
def test_get_nonexistent_command():
|
||||
"""Test that attempting to get a nonexistent command raises a KeyError."""
|
||||
registry = CommandRegistry()
|
||||
|
||||
assert registry.get_command("nonexistent_command") is None
|
||||
assert "nonexistent_command" not in registry
|
||||
|
||||
|
||||
def test_call_command(agent: Agent):
|
||||
"""Test that a command can be called through the registry."""
|
||||
registry = CommandRegistry()
|
||||
cmd = Command(
|
||||
name="example",
|
||||
description="Example command",
|
||||
method=example_command_method,
|
||||
parameters=PARAMETERS,
|
||||
)
|
||||
|
||||
registry.register(cmd)
|
||||
result = registry.call("example", arg1=1, arg2="test", agent=agent)
|
||||
|
||||
assert result == "1 - test"
|
||||
|
||||
|
||||
def test_call_nonexistent_command(agent: Agent):
|
||||
"""Test that attempting to call a nonexistent command raises a KeyError."""
|
||||
registry = CommandRegistry()
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
registry.call("nonexistent_command", arg1=1, arg2="test", agent=agent)
|
||||
|
||||
|
||||
def test_import_mock_commands_module():
|
||||
"""Test that the registry can import a module with mock command plugins."""
|
||||
registry = CommandRegistry()
|
||||
mock_commands_module = "tests.mocks.mock_commands"
|
||||
|
||||
registry.import_command_module(mock_commands_module)
|
||||
|
||||
assert "function_based_cmd" in registry
|
||||
assert registry.commands["function_based_cmd"].name == "function_based_cmd"
|
||||
assert (
|
||||
registry.commands["function_based_cmd"].description
|
||||
== "Function-based test command"
|
||||
)
|
||||
|
||||
|
||||
def test_import_temp_command_file_module(tmp_path: Path):
|
||||
"""
|
||||
Test that the registry can import a command plugins module from a temp file.
|
||||
Args:
|
||||
tmp_path (pathlib.Path): Path to a temporary directory.
|
||||
"""
|
||||
registry = CommandRegistry()
|
||||
|
||||
# Create a temp command file
|
||||
src = Path(os.getcwd()) / "tests/mocks/mock_commands.py"
|
||||
temp_commands_file = tmp_path / "mock_commands.py"
|
||||
shutil.copyfile(src, temp_commands_file)
|
||||
|
||||
# Add the temp directory to sys.path to make the module importable
|
||||
sys.path.append(str(tmp_path))
|
||||
|
||||
temp_commands_module = "mock_commands"
|
||||
registry.import_command_module(temp_commands_module)
|
||||
|
||||
# Remove the temp directory from sys.path
|
||||
sys.path.remove(str(tmp_path))
|
||||
|
||||
assert "function_based_cmd" in registry
|
||||
assert registry.commands["function_based_cmd"].name == "function_based_cmd"
|
||||
assert (
|
||||
registry.commands["function_based_cmd"].description
|
||||
== "Function-based test command"
|
||||
)
|
||||
@@ -14,7 +14,6 @@ from pydantic import SecretStr
|
||||
|
||||
from autogpt.app.configurator import GPT_3_MODEL, GPT_4_MODEL, apply_overrides_to_config
|
||||
from autogpt.config import Config, ConfigBuilder
|
||||
from autogpt.core.resource.model_providers.openai import OpenAIModelName
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
ChatModelInfo,
|
||||
ModelProviderName,
|
||||
@@ -39,8 +38,8 @@ async def test_fallback_to_gpt3_if_gpt4_not_available(
|
||||
"""
|
||||
Test if models update to gpt-3.5-turbo if gpt-4 is not available.
|
||||
"""
|
||||
config.fast_llm = OpenAIModelName.GPT4_TURBO
|
||||
config.smart_llm = OpenAIModelName.GPT4_TURBO
|
||||
config.fast_llm = GPT_4_MODEL
|
||||
config.smart_llm = GPT_4_MODEL
|
||||
|
||||
mock_list_models.return_value = asyncio.Future()
|
||||
mock_list_models.return_value.set_result(
|
||||
@@ -56,8 +55,8 @@ async def test_fallback_to_gpt3_if_gpt4_not_available(
|
||||
gpt4only=False,
|
||||
)
|
||||
|
||||
assert config.fast_llm == "gpt-3.5-turbo"
|
||||
assert config.smart_llm == "gpt-3.5-turbo"
|
||||
assert config.fast_llm == GPT_3_MODEL
|
||||
assert config.smart_llm == GPT_3_MODEL
|
||||
|
||||
|
||||
def test_missing_azure_config(config: Config) -> None:
|
||||
@@ -148,7 +147,7 @@ def test_azure_config(config_with_azure: Config) -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_config_gpt4only(config: Config) -> None:
|
||||
with mock.patch(
|
||||
"autogpt.core.resource.model_providers.openai.OpenAIProvider.get_available_models"
|
||||
"autogpt.core.resource.model_providers.multi.MultiProvider.get_available_models"
|
||||
) as mock_get_models:
|
||||
mock_get_models.return_value = [
|
||||
ChatModelInfo(
|
||||
@@ -168,7 +167,7 @@ async def test_create_config_gpt4only(config: Config) -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_config_gpt3only(config: Config) -> None:
|
||||
with mock.patch(
|
||||
"autogpt.core.resource.model_providers.openai.OpenAIProvider.get_available_models"
|
||||
"autogpt.core.resource.model_providers.multi.MultiProvider.get_available_models"
|
||||
) as mock_get_models:
|
||||
mock_get_models.return_value = [
|
||||
ChatModelInfo(
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
import autogpt.commands.file_operations as file_ops
|
||||
import autogpt.agents.features.agent_file_manager as file_ops
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.utils.exceptions import DuplicateOperationError
|
||||
from autogpt.config import Config
|
||||
from autogpt.file_storage import FileStorage
|
||||
from autogpt.memory.vector.memory_item import MemoryItem
|
||||
@@ -38,6 +36,11 @@ def mock_MemoryItem_from_text(
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_manager_component(agent: Agent):
|
||||
return agent.file_manager
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_file_name():
|
||||
return Path("test_file.txt")
|
||||
@@ -58,197 +61,76 @@ def test_nested_file(storage: FileStorage):
|
||||
return storage.get_path("nested/test_file.txt")
|
||||
|
||||
|
||||
def test_file_operations_log():
|
||||
all_logs = (
|
||||
"File Operation Logger\n"
|
||||
"write: path/to/file1.txt #checksum1\n"
|
||||
"write: path/to/file2.txt #checksum2\n"
|
||||
"write: path/to/file3.txt #checksum3\n"
|
||||
"append: path/to/file2.txt #checksum4\n"
|
||||
"delete: path/to/file3.txt\n"
|
||||
)
|
||||
logs = all_logs.split("\n")
|
||||
|
||||
expected = [
|
||||
("write", "path/to/file1.txt", "checksum1"),
|
||||
("write", "path/to/file2.txt", "checksum2"),
|
||||
("write", "path/to/file3.txt", "checksum3"),
|
||||
("append", "path/to/file2.txt", "checksum4"),
|
||||
("delete", "path/to/file3.txt", None),
|
||||
]
|
||||
assert list(file_ops.operations_from_log(logs)) == expected
|
||||
|
||||
|
||||
def test_is_duplicate_operation(agent: Agent, mocker: MockerFixture):
|
||||
# Prepare a fake state dictionary for the function to use
|
||||
state = {
|
||||
"path/to/file1.txt": "checksum1",
|
||||
"path/to/file2.txt": "checksum2",
|
||||
}
|
||||
mocker.patch.object(file_ops, "file_operations_state", lambda _: state)
|
||||
|
||||
# Test cases with write operations
|
||||
assert (
|
||||
file_ops.is_duplicate_operation(
|
||||
"write", Path("path/to/file1.txt"), agent, "checksum1"
|
||||
)
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
file_ops.is_duplicate_operation(
|
||||
"write", Path("path/to/file1.txt"), agent, "checksum2"
|
||||
)
|
||||
is False
|
||||
)
|
||||
assert (
|
||||
file_ops.is_duplicate_operation(
|
||||
"write", Path("path/to/file3.txt"), agent, "checksum3"
|
||||
)
|
||||
is False
|
||||
)
|
||||
# Test cases with append operations
|
||||
assert (
|
||||
file_ops.is_duplicate_operation(
|
||||
"append", Path("path/to/file1.txt"), agent, "checksum1"
|
||||
)
|
||||
is False
|
||||
)
|
||||
# Test cases with delete operations
|
||||
assert (
|
||||
file_ops.is_duplicate_operation("delete", Path("path/to/file1.txt"), agent)
|
||||
is False
|
||||
)
|
||||
assert (
|
||||
file_ops.is_duplicate_operation("delete", Path("path/to/file3.txt"), agent)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
# Test logging a file operation
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_operation(agent: Agent):
|
||||
await file_ops.log_operation("log_test", Path("path/to/test"), agent=agent)
|
||||
log_entry = agent.get_file_operation_lines()[-1]
|
||||
assert "log_test: path/to/test" in log_entry
|
||||
|
||||
|
||||
def test_text_checksum(file_content: str):
|
||||
checksum = file_ops.text_checksum(file_content)
|
||||
different_checksum = file_ops.text_checksum("other content")
|
||||
assert re.match(r"^[a-fA-F0-9]+$", checksum) is not None
|
||||
assert checksum != different_checksum
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_operation_with_checksum(agent: Agent):
|
||||
await file_ops.log_operation(
|
||||
"log_test", Path("path/to/test"), agent=agent, checksum="ABCDEF"
|
||||
)
|
||||
log_entry = agent.get_file_operation_lines()[-1]
|
||||
assert "log_test: path/to/test #ABCDEF" in log_entry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file(
|
||||
mock_MemoryItem_from_text,
|
||||
test_file_path: Path,
|
||||
file_content,
|
||||
file_manager_component,
|
||||
agent: Agent,
|
||||
):
|
||||
await agent.workspace.write_file(test_file_path.name, file_content)
|
||||
await file_ops.log_operation(
|
||||
"write", Path(test_file_path.name), agent, file_ops.text_checksum(file_content)
|
||||
)
|
||||
content = file_ops.read_file(test_file_path.name, agent=agent)
|
||||
await agent.file_manager.workspace.write_file(test_file_path.name, file_content)
|
||||
content = file_manager_component.read_file(test_file_path.name)
|
||||
assert content.replace("\r", "") == file_content
|
||||
|
||||
|
||||
def test_read_file_not_found(agent: Agent):
|
||||
def test_read_file_not_found(file_manager_component):
|
||||
filename = "does_not_exist.txt"
|
||||
with pytest.raises(FileNotFoundError):
|
||||
file_ops.read_file(filename, agent=agent)
|
||||
file_manager_component.read_file(filename)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_to_file_relative_path(test_file_name: Path, agent: Agent):
|
||||
async def test_write_to_file_relative_path(
|
||||
test_file_name: Path, file_manager_component, agent: Agent
|
||||
):
|
||||
new_content = "This is new content.\n"
|
||||
await file_ops.write_to_file(test_file_name, new_content, agent=agent)
|
||||
with open(agent.workspace.get_path(test_file_name), "r", encoding="utf-8") as f:
|
||||
await file_manager_component.write_to_file(test_file_name, new_content)
|
||||
with open(
|
||||
agent.file_manager.workspace.get_path(test_file_name), "r", encoding="utf-8"
|
||||
) as f:
|
||||
content = f.read()
|
||||
assert content == new_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_to_file_absolute_path(test_file_path: Path, agent: Agent):
|
||||
async def test_write_to_file_absolute_path(
|
||||
test_file_path: Path, file_manager_component
|
||||
):
|
||||
new_content = "This is new content.\n"
|
||||
await file_ops.write_to_file(test_file_path, new_content, agent=agent)
|
||||
await file_manager_component.write_to_file(test_file_path, new_content)
|
||||
with open(test_file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
assert content == new_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_logs_checksum(test_file_name: Path, agent: Agent):
|
||||
new_content = "This is new content.\n"
|
||||
new_checksum = file_ops.text_checksum(new_content)
|
||||
await file_ops.write_to_file(test_file_name, new_content, agent=agent)
|
||||
log_entry = agent.get_file_operation_lines()[-1]
|
||||
assert log_entry == f"write: {test_file_name} #{new_checksum}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_fails_if_content_exists(test_file_name: Path, agent: Agent):
|
||||
new_content = "This is new content.\n"
|
||||
await file_ops.log_operation(
|
||||
"write",
|
||||
test_file_name,
|
||||
agent=agent,
|
||||
checksum=file_ops.text_checksum(new_content),
|
||||
)
|
||||
with pytest.raises(DuplicateOperationError):
|
||||
await file_ops.write_to_file(test_file_name, new_content, agent=agent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_succeeds_if_content_different(
|
||||
test_file_path: Path, file_content: str, agent: Agent
|
||||
):
|
||||
await agent.workspace.write_file(test_file_path.name, file_content)
|
||||
await file_ops.log_operation(
|
||||
"write", Path(test_file_path.name), agent, file_ops.text_checksum(file_content)
|
||||
)
|
||||
new_content = "This is different content.\n"
|
||||
await file_ops.write_to_file(test_file_path.name, new_content, agent=agent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_files(agent: Agent):
|
||||
async def test_list_files(file_manager_component, agent: Agent):
|
||||
# Create files A and B
|
||||
file_a_name = "file_a.txt"
|
||||
file_b_name = "file_b.txt"
|
||||
test_directory = Path("test_directory")
|
||||
|
||||
await agent.workspace.write_file(file_a_name, "This is file A.")
|
||||
await agent.workspace.write_file(file_b_name, "This is file B.")
|
||||
await agent.file_manager.workspace.write_file(file_a_name, "This is file A.")
|
||||
await agent.file_manager.workspace.write_file(file_b_name, "This is file B.")
|
||||
|
||||
# Create a subdirectory and place a copy of file_a in it
|
||||
agent.workspace.make_dir(test_directory)
|
||||
await agent.workspace.write_file(
|
||||
agent.file_manager.workspace.make_dir(test_directory)
|
||||
await agent.file_manager.workspace.write_file(
|
||||
test_directory / file_a_name, "This is file A in the subdirectory."
|
||||
)
|
||||
|
||||
files = file_ops.list_folder(".", agent=agent)
|
||||
files = file_manager_component.list_folder(".")
|
||||
assert file_a_name in files
|
||||
assert file_b_name in files
|
||||
assert os.path.join(test_directory, file_a_name) in files
|
||||
|
||||
# Clean up
|
||||
agent.workspace.delete_file(file_a_name)
|
||||
agent.workspace.delete_file(file_b_name)
|
||||
agent.workspace.delete_file(test_directory / file_a_name)
|
||||
agent.workspace.delete_dir(test_directory)
|
||||
agent.file_manager.workspace.delete_file(file_a_name)
|
||||
agent.file_manager.workspace.delete_file(file_b_name)
|
||||
agent.file_manager.workspace.delete_file(test_directory / file_a_name)
|
||||
agent.file_manager.workspace.delete_dir(test_directory)
|
||||
|
||||
# Case 2: Search for a file that does not exist and make sure we don't throw
|
||||
non_existent_file = "non_existent_file.txt"
|
||||
files = file_ops.list_folder("", agent=agent)
|
||||
files = file_manager_component.list_folder("")
|
||||
assert non_existent_file not in files
|
||||
|
||||
@@ -3,9 +3,9 @@ from git.exc import GitCommandError
|
||||
from git.repo.base import Repo
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.utils.exceptions import CommandExecutionError
|
||||
from autogpt.commands.git_operations import clone_repository
|
||||
from autogpt.commands.git_operations import GitOperationsComponent
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.utils.exceptions import CommandExecutionError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -13,7 +13,17 @@ def mock_clone_from(mocker):
|
||||
return mocker.patch.object(Repo, "clone_from")
|
||||
|
||||
|
||||
def test_clone_auto_gpt_repository(storage: FileStorage, mock_clone_from, agent: Agent):
|
||||
@pytest.fixture
|
||||
def git_ops_component(agent: Agent):
|
||||
return agent.git_ops
|
||||
|
||||
|
||||
def test_clone_auto_gpt_repository(
|
||||
git_ops_component: GitOperationsComponent,
|
||||
storage: FileStorage,
|
||||
mock_clone_from,
|
||||
agent: Agent,
|
||||
):
|
||||
mock_clone_from.return_value = None
|
||||
|
||||
repo = "github.com/Significant-Gravitas/Auto-GPT.git"
|
||||
@@ -23,7 +33,7 @@ def test_clone_auto_gpt_repository(storage: FileStorage, mock_clone_from, agent:
|
||||
|
||||
expected_output = f"Cloned {url} to {clone_path}"
|
||||
|
||||
clone_result = clone_repository(url=url, clone_path=clone_path, agent=agent)
|
||||
clone_result = git_ops_component.clone_repository(url, clone_path)
|
||||
|
||||
assert clone_result == expected_output
|
||||
mock_clone_from.assert_called_once_with(
|
||||
@@ -32,7 +42,12 @@ def test_clone_auto_gpt_repository(storage: FileStorage, mock_clone_from, agent:
|
||||
)
|
||||
|
||||
|
||||
def test_clone_repository_error(storage: FileStorage, mock_clone_from, agent: Agent):
|
||||
def test_clone_repository_error(
|
||||
git_ops_component: GitOperationsComponent,
|
||||
storage: FileStorage,
|
||||
mock_clone_from,
|
||||
agent: Agent,
|
||||
):
|
||||
url = "https://github.com/this-repository/does-not-exist.git"
|
||||
clone_path = storage.get_path("does-not-exist")
|
||||
|
||||
@@ -41,4 +56,4 @@ def test_clone_repository_error(storage: FileStorage, mock_clone_from, agent: Ag
|
||||
)
|
||||
|
||||
with pytest.raises(CommandExecutionError):
|
||||
clone_repository(url=url, clone_path=clone_path, agent=agent)
|
||||
git_ops_component.clone_repository(url, clone_path)
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
import os
|
||||
|
||||
import yaml
|
||||
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.plugins import inspect_zip_for_modules, scan_plugins
|
||||
from autogpt.plugins.plugin_config import PluginConfig
|
||||
from autogpt.plugins.plugins_config import PluginsConfig
|
||||
|
||||
PLUGINS_TEST_DIR = "tests/unit/data/test_plugins"
|
||||
PLUGIN_TEST_ZIP_FILE = "Auto-GPT-Plugin-Test-master.zip"
|
||||
PLUGIN_TEST_INIT_PY = "Auto-GPT-Plugin-Test-master/src/auto_gpt_vicuna/__init__.py"
|
||||
PLUGIN_TEST_OPENAI = "https://weathergpt.vercel.app/"
|
||||
|
||||
|
||||
def test_scan_plugins_openai(config: Config):
|
||||
config.plugins_openai = [PLUGIN_TEST_OPENAI]
|
||||
plugins_config = config.plugins_config
|
||||
plugins_config.plugins[PLUGIN_TEST_OPENAI] = PluginConfig(
|
||||
name=PLUGIN_TEST_OPENAI, enabled=True
|
||||
)
|
||||
|
||||
# Test that the function returns the correct number of plugins
|
||||
result = scan_plugins(config)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
def test_scan_plugins_generic(config: Config):
|
||||
# Test that the function returns the correct number of plugins
|
||||
plugins_config = config.plugins_config
|
||||
plugins_config.plugins["auto_gpt_guanaco"] = PluginConfig(
|
||||
name="auto_gpt_guanaco", enabled=True
|
||||
)
|
||||
plugins_config.plugins["AutoGPTPVicuna"] = PluginConfig(
|
||||
name="AutoGPTPVicuna", enabled=True
|
||||
)
|
||||
result = scan_plugins(config)
|
||||
plugin_class_names = [plugin.__class__.__name__ for plugin in result]
|
||||
|
||||
assert len(result) == 2
|
||||
assert "AutoGPTGuanaco" in plugin_class_names
|
||||
assert "AutoGPTPVicuna" in plugin_class_names
|
||||
|
||||
|
||||
def test_scan_plugins_not_enabled(config: Config):
|
||||
# Test that the function returns the correct number of plugins
|
||||
plugins_config = config.plugins_config
|
||||
plugins_config.plugins["auto_gpt_guanaco"] = PluginConfig(
|
||||
name="auto_gpt_guanaco", enabled=True
|
||||
)
|
||||
plugins_config.plugins["auto_gpt_vicuna"] = PluginConfig(
|
||||
name="auto_gptp_vicuna", enabled=False
|
||||
)
|
||||
result = scan_plugins(config)
|
||||
plugin_class_names = [plugin.__class__.__name__ for plugin in result]
|
||||
|
||||
assert len(result) == 1
|
||||
assert "AutoGPTGuanaco" in plugin_class_names
|
||||
assert "AutoGPTPVicuna" not in plugin_class_names
|
||||
|
||||
|
||||
def test_inspect_zip_for_modules():
|
||||
result = inspect_zip_for_modules(str(f"{PLUGINS_TEST_DIR}/{PLUGIN_TEST_ZIP_FILE}"))
|
||||
assert result == [PLUGIN_TEST_INIT_PY]
|
||||
|
||||
|
||||
def test_create_base_config(config: Config):
|
||||
"""
|
||||
Test the backwards-compatibility shim to convert old plugin allow/deny list
|
||||
to a config file.
|
||||
"""
|
||||
config.plugins_allowlist = ["a", "b"]
|
||||
config.plugins_denylist = ["c", "d"]
|
||||
|
||||
os.remove(config.plugins_config_file)
|
||||
plugins_config = PluginsConfig.load_config(
|
||||
plugins_config_file=config.plugins_config_file,
|
||||
plugins_denylist=config.plugins_denylist,
|
||||
plugins_allowlist=config.plugins_allowlist,
|
||||
)
|
||||
|
||||
# Check the structure of the plugins config data
|
||||
assert len(plugins_config.plugins) == 4
|
||||
assert plugins_config.get("a").enabled
|
||||
assert plugins_config.get("b").enabled
|
||||
assert not plugins_config.get("c").enabled
|
||||
assert not plugins_config.get("d").enabled
|
||||
|
||||
# Check the saved config file
|
||||
with open(config.plugins_config_file, "r") as saved_config_file:
|
||||
saved_config = yaml.load(saved_config_file, Loader=yaml.SafeLoader)
|
||||
|
||||
assert saved_config == {
|
||||
"a": {"enabled": True, "config": {}},
|
||||
"b": {"enabled": True, "config": {}},
|
||||
"c": {"enabled": False, "config": {}},
|
||||
"d": {"enabled": False, "config": {}},
|
||||
}
|
||||
|
||||
|
||||
def test_load_config(config: Config):
|
||||
"""
|
||||
Test that the plugin config is loaded correctly from the plugins_config.yaml file.
|
||||
"""
|
||||
# Create a test config and write it to disk
|
||||
test_config = {
|
||||
"a": {"enabled": True, "config": {"api_key": "1234"}},
|
||||
"b": {"enabled": False, "config": {}},
|
||||
}
|
||||
with open(config.plugins_config_file, "w+") as f:
|
||||
f.write(yaml.dump(test_config))
|
||||
|
||||
# Load the config from disk
|
||||
plugins_config = PluginsConfig.load_config(
|
||||
plugins_config_file=config.plugins_config_file,
|
||||
plugins_denylist=config.plugins_denylist,
|
||||
plugins_allowlist=config.plugins_allowlist,
|
||||
)
|
||||
|
||||
# Check that the loaded config is equal to the test config
|
||||
assert len(plugins_config.plugins) == 2
|
||||
assert plugins_config.get("a").enabled
|
||||
assert plugins_config.get("a").config == {"api_key": "1234"}
|
||||
assert not plugins_config.get("b").enabled
|
||||
assert plugins_config.get("b").config == {}
|
||||
@@ -10,10 +10,7 @@ import pytest
|
||||
import yaml
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from autogpt.commands.file_operations_utils import (
|
||||
decode_textual_file,
|
||||
is_file_binary_fn,
|
||||
)
|
||||
from autogpt.utils.file_operations_utils import decode_textual_file, is_file_binary_fn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from autogpt.app.utils import (
|
||||
set_env_config_value,
|
||||
)
|
||||
from autogpt.core.utils.json_utils import extract_dict_from_json
|
||||
from autogpt.utils import validate_yaml_file
|
||||
from autogpt.utils.utils import validate_yaml_file
|
||||
from tests.utils import skip_in_ci
|
||||
|
||||
|
||||
|
||||
@@ -4,23 +4,32 @@ import pytest
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.utils.exceptions import ConfigurationError
|
||||
from autogpt.commands.web_search import google, safe_google_results, web_search
|
||||
from autogpt.commands.web_search import WebSearchComponent
|
||||
from autogpt.utils.exceptions import ConfigurationError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def web_search_component(agent: Agent):
|
||||
return agent.web_search
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query, expected_output",
|
||||
[("test", "test"), (["test1", "test2"], '["test1", "test2"]')],
|
||||
)
|
||||
def test_safe_google_results(query, expected_output):
|
||||
result = safe_google_results(query)
|
||||
@pytest.fixture
|
||||
def test_safe_google_results(
|
||||
query, expected_output, web_search_component: WebSearchComponent
|
||||
):
|
||||
result = web_search_component.safe_google_results(query)
|
||||
assert isinstance(result, str)
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
def test_safe_google_results_invalid_input():
|
||||
@pytest.fixture
|
||||
def test_safe_google_results_invalid_input(web_search_component: WebSearchComponent):
|
||||
with pytest.raises(AttributeError):
|
||||
safe_google_results(123)
|
||||
web_search_component.safe_google_results(123) # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -37,13 +46,18 @@ def test_safe_google_results_invalid_input():
|
||||
],
|
||||
)
|
||||
def test_google_search(
|
||||
query, num_results, expected_output_parts, return_value, mocker, agent: Agent
|
||||
query,
|
||||
num_results,
|
||||
expected_output_parts,
|
||||
return_value,
|
||||
mocker,
|
||||
web_search_component: WebSearchComponent,
|
||||
):
|
||||
mock_ddg = mocker.Mock()
|
||||
mock_ddg.return_value = return_value
|
||||
|
||||
mocker.patch("autogpt.commands.web_search.DDGS.text", mock_ddg)
|
||||
actual_output = web_search(query, agent=agent, num_results=num_results)
|
||||
actual_output = web_search_component.web_search(query, num_results=num_results)
|
||||
for o in expected_output_parts:
|
||||
assert o in actual_output
|
||||
|
||||
@@ -82,11 +96,11 @@ def test_google_official_search(
|
||||
expected_output,
|
||||
search_results,
|
||||
mock_googleapiclient,
|
||||
agent: Agent,
|
||||
web_search_component: WebSearchComponent,
|
||||
):
|
||||
mock_googleapiclient.return_value = search_results
|
||||
actual_output = google(query, agent=agent, num_results=num_results)
|
||||
assert actual_output == safe_google_results(expected_output)
|
||||
actual_output = web_search_component.google(query, num_results=num_results)
|
||||
assert actual_output == web_search_component.safe_google_results(expected_output)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -115,7 +129,7 @@ def test_google_official_search_errors(
|
||||
mock_googleapiclient,
|
||||
http_code,
|
||||
error_msg,
|
||||
agent: Agent,
|
||||
web_search_component: WebSearchComponent,
|
||||
):
|
||||
class resp:
|
||||
def __init__(self, _status, _reason):
|
||||
@@ -133,4 +147,4 @@ def test_google_official_search_errors(
|
||||
|
||||
mock_googleapiclient.side_effect = error
|
||||
with pytest.raises(expected_error_type):
|
||||
google(query, agent=agent, num_results=num_results)
|
||||
web_search_component.google(query, num_results=num_results)
|
||||
|
||||
2916
autogpts/forge/poetry.lock
generated
2916
autogpts/forge/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -28,6 +28,7 @@ agbenchmark = { path = "../../benchmark", optional = true }
|
||||
# agbenchmark = {git = "https://github.com/Significant-Gravitas/AutoGPT.git", subdirectory = "benchmark", optional = true}
|
||||
webdriver-manager = "^4.0.1"
|
||||
google-cloud-storage = "^2.13.0"
|
||||
click-default-group = "^1.2.4"
|
||||
|
||||
[tool.poetry.extras]
|
||||
benchmark = ["agbenchmark"]
|
||||
|
||||
@@ -18,7 +18,9 @@ class TestPasswordGenerator(unittest.TestCase):
|
||||
def test_password_content(self):
|
||||
password = password_generator.generate_password()
|
||||
self.assertTrue(any(c.isdigit() for c in password))
|
||||
self.assertTrue(any(c in password_generator.string.punctuation for c in password))
|
||||
self.assertTrue(
|
||||
any(c in password_generator.string.punctuation for c in password)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
from sample_code import multiply_int
|
||||
# from sample_code import multiply_int
|
||||
|
||||
|
||||
def test_multiply_int(num: int, multiplier, expected_result: int) -> None:
|
||||
result = multiply_int(num, multiplier)
|
||||
print(result)
|
||||
assert (
|
||||
result == expected_result
|
||||
), f"AssertionError: Expected the output to be {expected_result}"
|
||||
# def test_multiply_int(num: int, multiplier, expected_result: int) -> None:
|
||||
# result = multiply_int(num, multiplier)
|
||||
# print(result)
|
||||
# assert (
|
||||
# result == expected_result
|
||||
# ), f"AssertionError: Expected the output to be {expected_result}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# create a trivial test that has 4 as the num, and 2 as the multiplier. Make sure to fill in the expected result
|
||||
num =
|
||||
multiplier =
|
||||
expected_result =
|
||||
test_multiply_int()
|
||||
# if __name__ == "__main__":
|
||||
# # create a trivial test that has 4 as the num, and 2 as the multiplier. Make sure to fill in the expected result
|
||||
# num =
|
||||
# multiplier =
|
||||
# expected_result =
|
||||
# test_multiply_int()
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
from import
|
||||
|
||||
|
||||
def test_two_sum(nums: List, target: int, expected_result: List[int]) -> None:
|
||||
result = two_sum(nums, target)
|
||||
|
||||
@@ -21,7 +21,6 @@ def generate_password(length: int = 8) -> str:
|
||||
|
||||
if __name__ == "__main__":
|
||||
password_length = (
|
||||
int(sys.argv[sys.argv.index("--length") + 1])
|
||||
if "--length" in sys.argv else 8
|
||||
int(sys.argv[sys.argv.index("--length") + 1]) if "--length" in sys.argv else 8
|
||||
)
|
||||
print(generate_password(password_length))
|
||||
|
||||
@@ -18,7 +18,9 @@ class TestPasswordGenerator(unittest.TestCase):
|
||||
def test_password_content(self):
|
||||
password = password_generator.generate_password()
|
||||
self.assertTrue(any(c.isdigit() for c in password))
|
||||
self.assertTrue(any(c in password_generator.string.punctuation for c in password))
|
||||
self.assertTrue(
|
||||
any(c in password_generator.string.punctuation for c in password)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
from typing import Dict
|
||||
|
||||
from abstract_class import (AbstractBattleship, Game, GameStatus,
|
||||
ShipPlacement, Turn, TurnResponse)
|
||||
from abstract_class import (
|
||||
AbstractBattleship,
|
||||
Game,
|
||||
GameStatus,
|
||||
ShipPlacement,
|
||||
Turn,
|
||||
TurnResponse,
|
||||
)
|
||||
|
||||
|
||||
class Battleship(AbstractBattleship):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user