feat(forge): Component-specific configuration (#7170)

Remove many env vars and use component-level configuration that can be loaded from a file instead.

### Changed

- `BaseAgent` provides `serialize_configs` and `deserialize_configs` that can save and load all component configuration as json `str`. Deserialized components/values overwrite existing values, so not all values need to be present in the serialized config.
- Decoupled `forge/content_processing/text.py` from `Config`
- Kept `execute_local_commands` in `Config` because it's needed to know if OS info should be included in the prompt
- Updated docs to reflect changes
- Renamed `Config` to `AppConfig`

### Added

- Added `ConfigurableComponent` class for components and following configs:
  - `ActionHistoryConfiguration`
  - `CodeExecutorConfiguration`
  - `FileManagerConfiguration` - the file manager now allows multiple agents to use the same workspace
  - `GitOperationsConfiguration`
  - `ImageGeneratorConfiguration`
  - `WebSearchConfiguration`
  - `WebSeleniumConfiguration`
- `BaseConfig` in `forge` and moved `Config` (now inherits from `BaseConfig`) back to `autogpt`
- Required `config_class` attribute on the `ConfigurableComponent` class, which must be set to the component's configuration class
- `--component-config-file` CLI option, `COMPONENT_CONFIG_FILE` env var, and a field in `Config`. This option allows loading configuration from a specific file; the CLI option takes precedence over the env var.
- Added comments to config models

### Removed

- Unused `change_agent_id` method from `FileManagerComponent`
- Unused `allow_downloads` from `Config` and CLI options (it should be in web component config if needed)
- CLI option `--browser-name` (the option is inside `WebSeleniumConfiguration`)
- Unused `workspace_directory` from CLI options
- No longer needed variables from `Config` and docs
- Unused fields from `Config`: `image_size`, `audio_to_text_provider`, `huggingface_audio_to_text_model`
- Removed `files` and `workspace` class attributes from `FileManagerComponent`
This commit is contained in:
Krzysztof Czerwinski
2024-06-19 09:14:01 +01:00
committed by GitHub
parent 02dc198a9f
commit c19ab2b24f
47 changed files with 772 additions and 722 deletions

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import copy
import inspect
import json
import logging
from abc import ABCMeta, abstractmethod
from typing import (
@@ -18,12 +19,13 @@ from typing import (
)
from colorama import Fore
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, parse_raw_as, validator
from forge.agent import protocols
from forge.agent.components import (
AgentComponent,
ComponentEndpointError,
ConfigurableComponent,
EndpointPipelineError,
)
from forge.config.ai_directives import AIDirectives
@@ -45,6 +47,11 @@ DEFAULT_TRIGGERING_PROMPT = (
)
# HACK: This is a workaround wrapper to de/serialize component configs until pydantic v2
class ModelContainer(BaseModel):
    # Wrapper so pydantic can (de)serialize a heterogeneous mapping of config
    # models in one pass — a workaround until pydantic v2.
    models: dict[str, BaseModel]  # keyed by the config class name
class BaseAgentConfiguration(SystemConfiguration):
allow_fs_access: bool = UserConfigurable(default=False)
@@ -82,9 +89,6 @@ class BaseAgentConfiguration(SystemConfiguration):
defaults to 75% of `llm.max_tokens`.
"""
summary_max_tlength: Optional[int] = None
# TODO: move to ActionHistoryConfiguration
@validator("use_functions_api")
def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
if v:
@@ -272,6 +276,30 @@ class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
raise e
return method_result
def dump_component_configs(self) -> str:
    """Serialize the configs of all configurable components to a JSON string.

    The result maps each config class's name to its serialized fields, so it
    can be fed back into `load_component_configs`.
    """
    collected: dict[str, BaseModel] = {}
    for candidate in self.components:
        if not isinstance(candidate, ConfigurableComponent):
            continue
        collected[type(candidate.config).__name__] = candidate.config
    # Round-trip through ModelContainer so pydantic performs the nested
    # serialization of every config model, then unwrap to plain dicts.
    serialized = ModelContainer(models=collected).json()
    as_plain = parse_raw_as(dict[str, dict[str, Any]], serialized)
    return json.dumps(as_plain["models"], indent=4)
def load_component_configs(self, serialized_configs: str):
    """Update components' configs from a JSON string.

    Accepts output of `dump_component_configs`. Values present in the
    serialized data overwrite current values; everything else is kept, so a
    partial config is fine. Entries for unknown components are ignored.
    """
    incoming = parse_raw_as(dict[str, dict[str, Any]], serialized_configs)
    configurable = (
        c for c in self.components if isinstance(c, ConfigurableComponent)
    )
    for component in configurable:
        config_cls = type(component.config)
        overrides = incoming.get(config_cls.__name__)
        if overrides is None:
            continue
        # Current values act as defaults; serialized values win.
        merged = {**component.config.dict(), **overrides}
        component.config = config_cls(**merged)
def _collect_components(self):
components = [
getattr(self, attr)

View File

@@ -1,9 +1,14 @@
from __future__ import annotations
from abc import ABC
from typing import Callable, TypeVar
from typing import Callable, ClassVar, Generic, Optional, TypeVar
T = TypeVar("T", bound="AgentComponent")
from pydantic import BaseModel
from forge.models.config import _update_user_config_from_env, deep_update
AC = TypeVar("AC", bound="AgentComponent")
BM = TypeVar("BM", bound=BaseModel)
class AgentComponent(ABC):
@@ -24,7 +29,7 @@ class AgentComponent(ABC):
"""Return the reason this component is disabled."""
return self._disabled_reason
def run_after(self: T, *components: type[AgentComponent] | AgentComponent) -> T:
def run_after(self: AC, *components: type[AgentComponent] | AgentComponent) -> AC:
"""Set the components that this component should run after."""
for component in components:
t = component if isinstance(component, type) else type(component)
@@ -33,6 +38,39 @@ class AgentComponent(ABC):
return self
class ConfigurableComponent(ABC, Generic[BM]):
    """A component that can be configured with a Pydantic model."""

    # Concrete subclasses must point this at their configuration model;
    # enforced at class-creation time by __init_subclass__ below.
    config_class: ClassVar[type[BM]]  # type: ignore

    def __init__(self, configuration: Optional[BM]):
        # Lazily initialized: when no configuration is given, the `config`
        # property builds a default instance on first access.
        self._config: Optional[BM] = None
        if configuration is not None:
            self.config = configuration

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Fail fast if a subclass forgot to declare its config model.
        if getattr(cls, "config_class", None) is None:
            raise NotImplementedError(
                f"ConfigurableComponent subclass {cls.__name__} "
                "must define config_class class attribute."
            )

    @property
    def config(self) -> BM:
        # Build a default configuration on first access if none was set
        # (goes through the setter, so env overrides are applied).
        if not hasattr(self, "_config") or self._config is None:
            self.config = self.config_class()
        return self._config  # type: ignore

    @config.setter
    def config(self, config: BM):
        # Only the *first* assignment pulls overrides from environment
        # variables; subsequent assignments replace the config verbatim.
        if not hasattr(self, "_config") or self._config is None:
            # Load configuration from environment variables
            updated = _update_user_config_from_env(config)
            config = self.config_class(**deep_update(config.dict(), updated))
        self._config = config
class ComponentEndpointError(Exception):
"""Error of a single protocol method on a component."""

View File

@@ -30,6 +30,48 @@ class MyAgent(BaseAgent):
self.some_component = SomeComponent(self.hello_component)
```
## Component configuration
Each component can have its own configuration defined using a regular pydantic `BaseModel`.
To ensure the configuration is loaded from the file correctly, the component must inherit from `ConfigurableComponent[T]` where `T` is the configuration model it uses.
`ConfigurableComponent` provides a `config` attribute that holds the configuration instance.
It's possible to either set the `config` attribute directly or pass the configuration instance to the component's constructor.
Extra configuration (i.e. for components that are not part of the agent) can be passed and will be silently ignored. Extra config won't be applied even if the component is added later.
To see the configuration of built-in components visit [Built-in Components](./built-in-components.md).
```py
from pydantic import BaseModel
from forge.agent.components import ConfigurableComponent
class MyConfig(BaseModel):
some_value: str
class MyComponent(AgentComponent, ConfigurableComponent[MyConfig]):
def __init__(self, config: MyConfig):
super().__init__(config)
# This has the same effect as above:
# self.config = config
def get_some_value(self) -> str:
# Access the configuration like a regular model
return self.config.some_value
```
### Sensitive information
While it's possible to pass sensitive data directly in code to the configuration it's recommended to use `UserConfigurable(from_env="ENV_VAR_NAME", exclude=True)` field for sensitive data like API keys.
The data will be loaded from the environment variable but keep in mind that value passed in code takes precedence.
All fields, even excluded ones (`exclude=True`) will be loaded when the configuration is loaded from the file.
Exclusion allows you to skip them during *serialization*; a non-excluded `SecretStr` will be serialized literally as the string `"**********"`.
```py
from pydantic import BaseModel, SecretStr
from forge.models.config import UserConfigurable
class SensitiveConfig(BaseModel):
api_key: SecretStr = UserConfigurable(from_env="API_KEY", exclude=True)
```
## Ordering components
The execution order of components is important because some may depend on the results of the previous ones.
@@ -72,6 +114,7 @@ class MyAgent(Agent):
## Disabling components
You can control which components are enabled by setting their `_enabled` attribute.
Components are *enabled* by default.
Either provide a `bool` value or a `Callable[[], bool]`, which will be checked each time
the component is about to be executed. This way you can dynamically enable or disable
components based on some conditions.

View File

@@ -1,38 +1,54 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable, Iterator, Optional
from typing import Callable, Iterator, Optional
from pydantic import BaseModel
from forge.agent.components import ConfigurableComponent
from forge.agent.protocols import AfterExecute, AfterParse, MessageProvider
from forge.llm.prompting.utils import indent
from forge.llm.providers import ChatMessage, MultiProvider
if TYPE_CHECKING:
from forge.config.config import Config
from forge.llm.providers.multi import ModelName
from forge.llm.providers.openai import OpenAIModelName
from .model import ActionResult, AnyProposal, Episode, EpisodicActionHistory
class ActionHistoryComponent(MessageProvider, AfterParse[AnyProposal], AfterExecute):
class ActionHistoryConfiguration(BaseModel):
    """Configuration for ActionHistoryComponent's history compression."""

    model_name: ModelName = OpenAIModelName.GPT3
    """Name of the llm model used to compress the history"""
    max_tokens: int = 1024
    """Maximum number of tokens to use up with generated history messages"""
    spacy_language_model: str = "en_core_web_sm"
    """Language model used for summary chunking using spacy"""
class ActionHistoryComponent(
MessageProvider,
AfterParse[AnyProposal],
AfterExecute,
ConfigurableComponent[ActionHistoryConfiguration],
):
"""Keeps track of the event history and provides a summary of the steps."""
config_class = ActionHistoryConfiguration
def __init__(
self,
event_history: EpisodicActionHistory[AnyProposal],
max_tokens: int,
count_tokens: Callable[[str], int],
legacy_config: Config,
llm_provider: MultiProvider,
config: Optional[ActionHistoryConfiguration] = None,
) -> None:
ConfigurableComponent.__init__(self, config)
self.event_history = event_history
self.max_tokens = max_tokens
self.count_tokens = count_tokens
self.legacy_config = legacy_config
self.llm_provider = llm_provider
def get_messages(self) -> Iterator[ChatMessage]:
if progress := self._compile_progress(
self.event_history.episodes,
self.max_tokens,
self.config.max_tokens,
self.count_tokens,
):
yield ChatMessage.system(f"## Progress on your Task so far\n\n{progress}")
@@ -43,7 +59,7 @@ class ActionHistoryComponent(MessageProvider, AfterParse[AnyProposal], AfterExec
async def after_execute(self, result: ActionResult) -> None:
self.event_history.register_result(result)
await self.event_history.handle_compression(
self.llm_provider, self.legacy_config
self.llm_provider, self.config.model_name, self.config.spacy_language_model
)
def _compile_progress(

View File

@@ -8,11 +8,11 @@ from pydantic.generics import GenericModel
from forge.content_processing.text import summarize_text
from forge.llm.prompting.utils import format_numbered_list, indent
from forge.llm.providers.multi import ModelName
from forge.models.action import ActionResult, AnyProposal
from forge.models.utils import ModelWithSummary
if TYPE_CHECKING:
from forge.config.config import Config
from forge.llm.providers import MultiProvider
@@ -108,7 +108,10 @@ class EpisodicActionHistory(GenericModel, Generic[AnyProposal]):
self.cursor = len(self.episodes)
async def handle_compression(
self, llm_provider: MultiProvider, app_config: Config
self,
llm_provider: MultiProvider,
model_name: ModelName,
spacy_model: str,
) -> None:
"""Compresses each episode in the action history using an LLM.
@@ -131,7 +134,8 @@ class EpisodicActionHistory(GenericModel, Generic[AnyProposal]):
episode.format(),
instruction=compress_instruction,
llm_provider=llm_provider,
config=app_config,
model_name=model_name,
spacy_model=spacy_model,
)
for episode in episodes_to_summarize
]

View File

@@ -1,9 +1,4 @@
from .code_executor import (
ALLOWLIST_CONTROL,
DENYLIST_CONTROL,
CodeExecutionError,
CodeExecutorComponent,
)
from .code_executor import CodeExecutionError, CodeExecutorComponent
__all__ = [
"ALLOWLIST_CONTROL",

View File

@@ -5,16 +5,16 @@ import shlex
import string
import subprocess
from pathlib import Path
from typing import Iterator
from typing import Iterator, Literal, Optional
import docker
from docker.errors import DockerException, ImageNotFound, NotFound
from docker.models.containers import Container as DockerContainer
from pydantic import BaseModel, Field
from forge.agent import BaseAgentSettings
from forge.agent.components import ConfigurableComponent
from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.file_storage import FileStorage
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import (
@@ -25,9 +25,6 @@ from forge.utils.exceptions import (
logger = logging.getLogger(__name__)
ALLOWLIST_CONTROL = "allowlist"
DENYLIST_CONTROL = "denylist"
def we_are_running_in_a_docker_container() -> bool:
"""Check if we are running in a Docker container
@@ -56,15 +53,45 @@ class CodeExecutionError(CommandExecutionError):
"""The operation (an attempt to run arbitrary code) returned an error"""
class CodeExecutorComponent(CommandProvider):
class CodeExecutorConfiguration(BaseModel):
    """Configuration for CodeExecutorComponent (shell and Docker execution)."""

    execute_local_commands: bool = False
    """Enable shell command execution"""
    shell_command_control: Literal["allowlist", "denylist"] = "allowlist"
    """Controls which list is used"""
    shell_allowlist: list[str] = Field(default_factory=list)
    """List of allowed shell commands"""
    shell_denylist: list[str] = Field(default_factory=list)
    """List of prohibited shell commands"""
    docker_container_name: str = "agent_sandbox"
    """Name of the Docker container used for code execution"""
class CodeExecutorComponent(
CommandProvider, ConfigurableComponent[CodeExecutorConfiguration]
):
"""Provides commands to execute Python code and shell commands."""
config_class = CodeExecutorConfiguration
def __init__(
self, workspace: FileStorage, state: BaseAgentSettings, config: Config
self,
workspace: FileStorage,
config: Optional[CodeExecutorConfiguration] = None,
):
ConfigurableComponent.__init__(self, config)
self.workspace = workspace
self.state = state
self.legacy_config = config
# Change container name if it's empty or default to prevent different agents
# from using the same container
default_container_name = self.config.__fields__["docker_container_name"].default
if (
not self.config.docker_container_name
or self.config.docker_container_name == default_container_name
):
random_suffix = "".join(random.choices(string.ascii_lowercase, k=8))
self.config.docker_container_name = (
f"{default_container_name}_{random_suffix}"
)
if not we_are_running_in_a_docker_container() and not is_docker_available():
logger.info(
@@ -72,7 +99,7 @@ class CodeExecutorComponent(CommandProvider):
"The code execution commands will not be available."
)
if not self.legacy_config.execute_local_commands:
if not self.config.execute_local_commands:
logger.info(
"Local shell commands are disabled. To enable them,"
" set EXECUTE_LOCAL_COMMANDS to 'True' in your config file."
@@ -83,7 +110,7 @@ class CodeExecutorComponent(CommandProvider):
yield self.execute_python_code
yield self.execute_python_file
if self.legacy_config.execute_local_commands:
if self.config.execute_local_commands:
yield self.execute_shell
yield self.execute_shell_popen
@@ -192,7 +219,7 @@ class CodeExecutorComponent(CommandProvider):
logger.debug("App is not running in a Docker container")
return self._run_python_code_in_docker(file_path, args)
def validate_command(self, command_line: str, config: Config) -> tuple[bool, bool]:
def validate_command(self, command_line: str) -> tuple[bool, bool]:
"""Check whether a command is allowed and whether it may be executed in a shell.
If shell command control is enabled, we disallow executing in a shell, because
@@ -211,10 +238,10 @@ class CodeExecutorComponent(CommandProvider):
command_name = shlex.split(command_line)[0]
if config.shell_command_control == ALLOWLIST_CONTROL:
return command_name in config.shell_allowlist, False
elif config.shell_command_control == DENYLIST_CONTROL:
return command_name not in config.shell_denylist, False
if self.config.shell_command_control == "allowlist":
return command_name in self.config.shell_allowlist, False
elif self.config.shell_command_control == "denylist":
return command_name not in self.config.shell_denylist, False
else:
return True, True
@@ -238,9 +265,7 @@ class CodeExecutorComponent(CommandProvider):
Returns:
str: The output of the command
"""
allow_execute, allow_shell = self.validate_command(
command_line, self.legacy_config
)
allow_execute, allow_shell = self.validate_command(command_line)
if not allow_execute:
logger.info(f"Command '{command_line}' not allowed")
raise OperationNotAllowedError("This shell command is not allowed.")
@@ -287,9 +312,7 @@ class CodeExecutorComponent(CommandProvider):
Returns:
str: Description of the fact that the process started and its id
"""
allow_execute, allow_shell = self.validate_command(
command_line, self.legacy_config
)
allow_execute, allow_shell = self.validate_command(command_line)
if not allow_execute:
logger.info(f"Command '{command_line}' not allowed")
raise OperationNotAllowedError("This shell command is not allowed.")
@@ -320,12 +343,10 @@ class CodeExecutorComponent(CommandProvider):
"""Run a Python script in a Docker container"""
file_path = self.workspace.get_path(filename)
try:
assert self.state.agent_id, "Need Agent ID to attach Docker container"
client = docker.from_env()
image_name = "python:3-alpine"
container_is_fresh = False
container_name = f"{self.state.agent_id}_sandbox"
container_name = self.config.docker_container_name
with self.workspace.mount() as local_path:
try:
container: DockerContainer = client.containers.get(

View File

@@ -3,7 +3,10 @@ import os
from pathlib import Path
from typing import Iterator, Optional
from pydantic import BaseModel
from forge.agent import BaseAgentSettings
from forge.agent.components import ConfigurableComponent
from forge.agent.protocols import CommandProvider, DirectiveProvider
from forge.command import Command, command
from forge.file_storage.base import FileStorage
@@ -13,67 +16,89 @@ from forge.utils.file_operations import decode_textual_file
logger = logging.getLogger(__name__)
class FileManagerComponent(DirectiveProvider, CommandProvider):
class FileManagerConfiguration(BaseModel):
    """Storage locations used by FileManagerComponent."""

    storage_path: str
    """Path to agent files, e.g. state"""
    workspace_path: str
    """Path to files that agent has access to"""

    class Config:
        # Prevent mutation of the configuration
        # as this wouldn't be reflected in the file storage
        allow_mutation = False
class FileManagerComponent(
DirectiveProvider, CommandProvider, ConfigurableComponent[FileManagerConfiguration]
):
"""
Adds general file manager (e.g. Agent state),
workspace manager (e.g. Agent output files) support and
commands to perform operations on files and folders.
"""
files: FileStorage
"""Agent-related files, e.g. state, logs.
Use `workspace` to access the agent's workspace files."""
workspace: FileStorage
"""Workspace that the agent has access to, e.g. for reading/writing files.
Use `files` to access agent-related files, e.g. state, logs."""
config_class = FileManagerConfiguration
STATE_FILE = "state.json"
"""The name of the file where the agent's state is stored."""
def __init__(self, state: BaseAgentSettings, file_storage: FileStorage):
self.state = state
def __init__(
self,
file_storage: FileStorage,
agent_state: BaseAgentSettings,
config: Optional[FileManagerConfiguration] = None,
):
"""Initialise the FileManagerComponent.
Either `agent_id` or `config` must be provided.
if not state.agent_id:
Args:
file_storage (FileStorage): The file storage instance to use.
state (BaseAgentSettings): The agent's state.
config (FileManagerConfiguration, optional): The configuration for
the file manager. Defaults to None.
"""
if not agent_state.agent_id:
raise ValueError("Agent must have an ID.")
self.files = file_storage.clone_with_subroot(f"agents/{state.agent_id}/")
self.workspace = file_storage.clone_with_subroot(
f"agents/{state.agent_id}/workspace"
)
self.agent_state = agent_state
if not config:
storage_path = f"agents/{self.agent_state.agent_id}/"
workspace_path = f"agents/{self.agent_state.agent_id}/workspace"
ConfigurableComponent.__init__(
self,
FileManagerConfiguration(
storage_path=storage_path, workspace_path=workspace_path
),
)
else:
ConfigurableComponent.__init__(self, config)
self.storage = file_storage.clone_with_subroot(self.config.storage_path)
"""Agent-related files, e.g. state, logs.
Use `workspace` to access the agent's workspace files."""
self.workspace = file_storage.clone_with_subroot(self.config.workspace_path)
"""Workspace that the agent has access to, e.g. for reading/writing files.
Use `storage` to access agent-related files, e.g. state, logs."""
self._file_storage = file_storage
async def save_state(self, save_as: Optional[str] = None) -> None:
"""Save the agent's state to the state file."""
state: BaseAgentSettings = getattr(self, "state")
if save_as:
temp_id = state.agent_id
state.agent_id = save_as
self._file_storage.make_dir(f"agents/{save_as}")
async def save_state(self, save_as_id: Optional[str] = None) -> None:
"""Save the agent's data and state."""
if save_as_id:
self._file_storage.make_dir(f"agents/{save_as_id}")
# Save state
await self._file_storage.write_file(
f"agents/{save_as}/{self.STATE_FILE}", state.json()
f"agents/{save_as_id}/{self.STATE_FILE}", self.agent_state.json()
)
# Copy workspace
self._file_storage.copy(
f"agents/{temp_id}/workspace",
f"agents/{save_as}/workspace",
self.config.workspace_path,
f"agents/{save_as_id}/workspace",
)
state.agent_id = temp_id
else:
await self.files.write_file(self.files.root / self.STATE_FILE, state.json())
def change_agent_id(self, new_id: str):
"""Change the agent's ID and update the file storage accordingly."""
state: BaseAgentSettings = getattr(self, "state")
# Rename the agent's files and workspace
self._file_storage.rename(f"agents/{state.agent_id}", f"agents/{new_id}")
# Update the file storage objects
self.files = self._file_storage.clone_with_subroot(f"agents/{new_id}/")
self.workspace = self._file_storage.clone_with_subroot(
f"agents/{new_id}/workspace"
)
state.agent_id = new_id
await self.storage.write_file(
self.storage.root / self.STATE_FILE, self.agent_state.json()
)
def get_resources(self) -> Iterator[str]:
yield "The ability to read and write files."

View File

@@ -1,23 +1,36 @@
from pathlib import Path
from typing import Iterator
from typing import Iterator, Optional
from git.repo import Repo
from pydantic import BaseModel, SecretStr
from forge.agent.components import ConfigurableComponent
from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.models.config import UserConfigurable
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import CommandExecutionError
from forge.utils.url_validator import validate_url
class GitOperationsComponent(CommandProvider):
class GitOperationsConfiguration(BaseModel):
    """GitHub credentials used by GitOperationsComponent."""

    # Falls back to the GITHUB_USERNAME env var when not set in code
    github_username: Optional[str] = UserConfigurable(from_env="GITHUB_USERNAME")
    # Falls back to GITHUB_API_KEY env var; excluded from serialization
    github_api_key: Optional[SecretStr] = UserConfigurable(
        from_env="GITHUB_API_KEY", exclude=True
    )
class GitOperationsComponent(
CommandProvider, ConfigurableComponent[GitOperationsConfiguration]
):
"""Provides commands to perform Git operations."""
def __init__(self, config: Config):
self._enabled = bool(config.github_username and config.github_api_key)
config_class = GitOperationsConfiguration
def __init__(self, config: Optional[GitOperationsConfiguration] = None):
ConfigurableComponent.__init__(self, config)
self._enabled = bool(self.config.github_username and self.config.github_api_key)
self._disabled_reason = "Configure github_username and github_api_key."
self.legacy_config = config
def get_commands(self) -> Iterator[Command]:
yield self.clone_repository
@@ -48,9 +61,13 @@ class GitOperationsComponent(CommandProvider):
str: The result of the clone operation.
"""
split_url = url.split("//")
auth_repo_url = (
f"//{self.legacy_config.github_username}:"
f"{self.legacy_config.github_api_key}@".join(split_url)
api_key = (
self.config.github_api_key.get_secret_value()
if self.config.github_api_key
else None
)
auth_repo_url = f"//{self.config.github_username}:" f"{api_key}@".join(
split_url
)
try:
Repo.clone_from(url=auth_repo_url, to_path=clone_path)

View File

@@ -1,5 +1,3 @@
"""Commands to generate images based on text input"""
import io
import json
import logging
@@ -7,35 +5,61 @@ import time
import uuid
from base64 import b64decode
from pathlib import Path
from typing import Iterator
from typing import Iterator, Literal, Optional
import requests
from openai import OpenAI
from PIL import Image
from pydantic import BaseModel, SecretStr
from forge.agent.components import ConfigurableComponent
from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.file_storage import FileStorage
from forge.llm.providers.openai import OpenAICredentials
from forge.models.config import UserConfigurable
from forge.models.json_schema import JSONSchema
logger = logging.getLogger(__name__)
class ImageGeneratorComponent(CommandProvider):
class ImageGeneratorConfiguration(BaseModel):
    """Provider selection and credentials for ImageGeneratorComponent."""

    # Which backend generates the images
    image_provider: Literal["dalle", "huggingface", "sdwebui"] = "dalle"
    # Model used when image_provider is "huggingface"
    huggingface_image_model: str = "CompVis/stable-diffusion-v1-4"
    # Falls back to HUGGINGFACE_API_TOKEN env var; excluded from serialization
    huggingface_api_token: Optional[SecretStr] = UserConfigurable(
        from_env="HUGGINGFACE_API_TOKEN", exclude=True
    )
    # Base URL of the Stable Diffusion WebUI API
    sd_webui_url: str = "http://localhost:7860"
    # "username:password" basic-auth pair for the WebUI; falls back to
    # SD_WEBUI_AUTH env var; excluded from serialization
    sd_webui_auth: Optional[SecretStr] = UserConfigurable(
        from_env="SD_WEBUI_AUTH", exclude=True
    )
class ImageGeneratorComponent(
CommandProvider, ConfigurableComponent[ImageGeneratorConfiguration]
):
"""A component that provides commands to generate images from text prompts."""
def __init__(self, workspace: FileStorage, config: Config):
self._enabled = bool(config.image_provider)
config_class = ImageGeneratorConfiguration
def __init__(
self,
workspace: FileStorage,
config: Optional[ImageGeneratorConfiguration] = None,
openai_credentials: Optional[OpenAICredentials] = None,
):
"""openai_credentials only needed for `dalle` provider."""
ConfigurableComponent.__init__(self, config)
self.openai_credentials = openai_credentials
self._enabled = bool(self.config.image_provider)
self._disabled_reason = "No image provider set."
self.workspace = workspace
self.legacy_config = config
def get_commands(self) -> Iterator[Command]:
if (
self.legacy_config.openai_credentials
or self.legacy_config.huggingface_api_token
or self.legacy_config.sd_webui_auth
self.openai_credentials
or self.config.huggingface_api_token
or self.config.sd_webui_auth
):
yield self.generate_image
@@ -48,7 +72,7 @@ class ImageGeneratorComponent(CommandProvider):
),
"size": JSONSchema(
type=JSONSchema.Type.INTEGER,
description="The size of the image",
description="The size of the image [256, 512, 1024]",
required=False,
),
},
@@ -65,22 +89,21 @@ class ImageGeneratorComponent(CommandProvider):
str: The filename of the image
"""
filename = self.workspace.root / f"{str(uuid.uuid4())}.jpg"
cfg = self.legacy_config
if cfg.openai_credentials and (
cfg.image_provider == "dalle"
or not (cfg.huggingface_api_token or cfg.sd_webui_url)
if self.openai_credentials and (
self.config.image_provider == "dalle"
or not (self.config.huggingface_api_token or self.config.sd_webui_url)
):
return self.generate_image_with_dalle(prompt, filename, size)
elif cfg.huggingface_api_token and (
cfg.image_provider == "huggingface"
or not (cfg.openai_credentials or cfg.sd_webui_url)
elif self.config.huggingface_api_token and (
self.config.image_provider == "huggingface"
or not (self.openai_credentials or self.config.sd_webui_url)
):
return self.generate_image_with_hf(prompt, filename)
elif cfg.sd_webui_url and (
cfg.image_provider == "sdwebui" or cfg.sd_webui_auth
elif self.config.sd_webui_url and (
self.config.image_provider == "sdwebui" or self.config.sd_webui_auth
):
return self.generate_image_with_sd_webui(prompt, filename, size)
@@ -96,13 +119,15 @@ class ImageGeneratorComponent(CommandProvider):
Returns:
str: The filename of the image
"""
API_URL = f"https://api-inference.huggingface.co/models/{self.legacy_config.huggingface_image_model}" # noqa: E501
if self.legacy_config.huggingface_api_token is None:
API_URL = f"https://api-inference.huggingface.co/models/{self.config.huggingface_image_model}" # noqa: E501
if self.config.huggingface_api_token is None:
raise ValueError(
"You need to set your Hugging Face API token in the config file."
)
headers = {
"Authorization": f"Bearer {self.legacy_config.huggingface_api_token}",
"Authorization": (
f"Bearer {self.config.huggingface_api_token.get_secret_value()}"
),
"X-Use-Cache": "false",
}
@@ -156,7 +181,7 @@ class ImageGeneratorComponent(CommandProvider):
Returns:
str: The filename of the image
"""
assert self.legacy_config.openai_credentials # otherwise this tool is disabled
assert self.openai_credentials # otherwise this tool is disabled
# Check for supported image sizes
if size not in [256, 512, 1024]:
@@ -169,7 +194,10 @@ class ImageGeneratorComponent(CommandProvider):
# TODO: integrate in `forge.llm.providers`(?)
response = OpenAI(
api_key=self.legacy_config.openai_credentials.api_key.get_secret_value()
api_key=self.openai_credentials.api_key.get_secret_value(),
organization=self.openai_credentials.organization.get_secret_value()
if self.openai_credentials.organization
else None,
).images.generate(
prompt=prompt,
n=1,
@@ -208,13 +236,13 @@ class ImageGeneratorComponent(CommandProvider):
"""
# Create a session and set the basic auth if needed
s = requests.Session()
if self.legacy_config.sd_webui_auth:
username, password = self.legacy_config.sd_webui_auth.split(":")
if self.config.sd_webui_auth:
username, password = self.config.sd_webui_auth.get_secret_value().split(":")
s.auth = (username, password or "")
# Generate the images
response = requests.post(
f"{self.legacy_config.sd_webui_url}/sdapi/v1/txt2img",
f"{self.config.sd_webui_url}/sdapi/v1/txt2img",
json={
"prompt": prompt,
"negative_prompt": negative_prompt,

View File

@@ -4,7 +4,6 @@ import click
from forge.agent.protocols import CommandProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.models.json_schema import JSONSchema
from forge.utils.const import ASK_COMMAND
@@ -12,9 +11,6 @@ from forge.utils.const import ASK_COMMAND
class UserInteractionComponent(CommandProvider):
"""Provides commands to interact with the user."""
def __init__(self, config: Config):
self._enabled = not config.noninteractive_mode
def get_commands(self) -> Iterator[Command]:
yield self.ask_user

View File

@@ -22,7 +22,7 @@ class WatchdogComponent(AfterParse[AnyProposal]):
def __init__(
self,
config: "BaseAgentConfiguration",
config: BaseAgentConfiguration,
event_history: EpisodicActionHistory[AnyProposal],
):
self.config = config

View File

@@ -1,30 +1,44 @@
import json
import logging
import time
from typing import Iterator
from typing import Iterator, Optional
from duckduckgo_search import DDGS
from pydantic import BaseModel, SecretStr
from forge.agent.components import ConfigurableComponent
from forge.agent.protocols import CommandProvider, DirectiveProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.models.config import UserConfigurable
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import ConfigurationError
DUCKDUCKGO_MAX_ATTEMPTS = 3
logger = logging.getLogger(__name__)
class WebSearchComponent(DirectiveProvider, CommandProvider):
class WebSearchConfiguration(BaseModel):
    """Configuration for the web search component."""

    # Google Custom Search credentials; excluded from serialization because
    # they are secrets (loaded from the environment instead).
    google_api_key: Optional[SecretStr] = UserConfigurable(
        from_env="GOOGLE_API_KEY", exclude=True
    )
    google_custom_search_engine_id: Optional[SecretStr] = UserConfigurable(
        from_env="GOOGLE_CUSTOM_SEARCH_ENGINE_ID", exclude=True
    )
    # Maximum number of retry attempts for a DuckDuckGo search
    duckduckgo_max_attempts: int = 3
class WebSearchComponent(
DirectiveProvider, CommandProvider, ConfigurableComponent[WebSearchConfiguration]
):
"""Provides commands to search the web."""
def __init__(self, config: Config):
self.legacy_config = config
config_class = WebSearchConfiguration
def __init__(self, config: Optional[WebSearchConfiguration] = None):
ConfigurableComponent.__init__(self, config)
if (
not self.legacy_config.google_api_key
or not self.legacy_config.google_custom_search_engine_id
not self.config.google_api_key
or not self.config.google_custom_search_engine_id
):
logger.info(
"Configure google_api_key and custom_search_engine_id "
@@ -37,10 +51,7 @@ class WebSearchComponent(DirectiveProvider, CommandProvider):
def get_commands(self) -> Iterator[Command]:
yield self.web_search
if (
self.legacy_config.google_api_key
and self.legacy_config.google_custom_search_engine_id
):
if self.config.google_api_key and self.config.google_custom_search_engine_id:
yield self.google
@command(
@@ -74,7 +85,7 @@ class WebSearchComponent(DirectiveProvider, CommandProvider):
search_results = []
attempts = 0
while attempts < DUCKDUCKGO_MAX_ATTEMPTS:
while attempts < self.config.duckduckgo_max_attempts:
if not query:
return json.dumps(search_results)
@@ -136,17 +147,25 @@ class WebSearchComponent(DirectiveProvider, CommandProvider):
from googleapiclient.errors import HttpError
try:
# Get the Google API key and Custom Search Engine ID from the config file
api_key = self.legacy_config.google_api_key
custom_search_engine_id = self.legacy_config.google_custom_search_engine_id
# Should be the case if this command is enabled:
assert self.config.google_api_key
assert self.config.google_custom_search_engine_id
# Initialize the Custom Search API service
service = build("customsearch", "v1", developerKey=api_key)
service = build(
"customsearch",
"v1",
developerKey=self.config.google_api_key.get_secret_value(),
)
# Send the search query and retrieve the results
result = (
service.cse()
.list(q=query, cx=custom_search_engine_id, num=num_results)
.list(
q=query,
cx=self.config.google_custom_search_engine_id.get_secret_value(),
num=num_results,
)
.execute()
)

View File

@@ -3,10 +3,11 @@ import logging
import re
from pathlib import Path
from sys import platform
from typing import Iterator, Type
from typing import Iterator, Literal, Optional, Type
from urllib.request import urlretrieve
from bs4 import BeautifulSoup
from pydantic import BaseModel
from selenium.common.exceptions import WebDriverException
from selenium.webdriver.chrome.options import Options as ChromeOptions
from selenium.webdriver.chrome.service import Service as ChromeDriverService
@@ -27,12 +28,14 @@ from webdriver_manager.chrome import ChromeDriverManager
from webdriver_manager.firefox import GeckoDriverManager
from webdriver_manager.microsoft import EdgeChromiumDriverManager as EdgeDriverManager
from forge.agent.components import ConfigurableComponent
from forge.agent.protocols import CommandProvider, DirectiveProvider
from forge.command import Command, command
from forge.config.config import Config
from forge.content_processing.html import extract_hyperlinks, format_hyperlinks
from forge.content_processing.text import extract_information, summarize_text
from forge.llm.providers import ChatModelInfo, MultiProvider
from forge.llm.providers import MultiProvider
from forge.llm.providers.multi import ModelName
from forge.llm.providers.openai import OpenAIModelName
from forge.models.json_schema import JSONSchema
from forge.utils.exceptions import CommandExecutionError, TooMuchOutputError
from forge.utils.url_validator import validate_url
@@ -51,18 +54,38 @@ class BrowsingError(CommandExecutionError):
"""An error occurred while trying to browse the page"""
class WebSeleniumComponent(DirectiveProvider, CommandProvider):
class WebSeleniumConfiguration(BaseModel):
    """Configuration for the Selenium-based web browsing component."""

    model_name: ModelName = OpenAIModelName.GPT3
    """Name of the llm model used to read websites"""
    web_browser: Literal["chrome", "firefox", "safari", "edge"] = "chrome"
    """Web browser used by Selenium"""
    headless: bool = True
    """Run browser in headless mode"""
    user_agent: str = (
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) "
        "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36"
    )
    """User agent used by the browser"""
    browse_spacy_language_model: str = "en_core_web_sm"
    """Spacy language model used for chunking text"""
class WebSeleniumComponent(
DirectiveProvider, CommandProvider, ConfigurableComponent[WebSeleniumConfiguration]
):
"""Provides commands to browse the web using Selenium."""
config_class = WebSeleniumConfiguration
def __init__(
self,
config: Config,
llm_provider: MultiProvider,
model_info: ChatModelInfo,
data_dir: Path,
config: Optional[WebSeleniumConfiguration] = None,
):
self.legacy_config = config
ConfigurableComponent.__init__(self, config)
self.llm_provider = llm_provider
self.model_info = model_info
self.data_dir = data_dir
def get_resources(self) -> Iterator[str]:
yield "Ability to read websites."
@@ -129,7 +152,7 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
"""
driver = None
try:
driver = await self.open_page_in_browser(url, self.legacy_config)
driver = await self.open_page_in_browser(url)
text = self.scrape_text_with_selenium(driver)
links = self.scrape_links_with_selenium(driver, url)
@@ -141,7 +164,7 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
elif get_raw_content:
if (
output_tokens := self.llm_provider.count_tokens(
text, self.model_info.name
text, self.config.model_name
)
) > MAX_RAW_CONTENT_LENGTH:
oversize_factor = round(output_tokens / MAX_RAW_CONTENT_LENGTH, 1)
@@ -228,7 +251,7 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
return format_hyperlinks(hyperlinks)
async def open_page_in_browser(self, url: str, config: Config) -> WebDriver:
async def open_page_in_browser(self, url: str) -> WebDriver:
"""Open a browser window and load a web page using Selenium
Params:
@@ -248,11 +271,11 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
"safari": SafariOptions,
}
options: BrowserOptions = options_available[config.selenium_web_browser]()
options.add_argument(f"user-agent={config.user_agent}")
options: BrowserOptions = options_available[self.config.web_browser]()
options.add_argument(f"user-agent={self.config.user_agent}")
if isinstance(options, FirefoxOptions):
if config.selenium_headless:
if self.config.headless:
options.headless = True # type: ignore
options.add_argument("--disable-gpu")
driver = FirefoxDriver(
@@ -274,13 +297,11 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
options.add_argument("--remote-debugging-port=9222")
options.add_argument("--no-sandbox")
if config.selenium_headless:
if self.config.headless:
options.add_argument("--headless=new")
options.add_argument("--disable-gpu")
self._sideload_chrome_extensions(
options, config.app_data_dir / "assets" / "crx"
)
self._sideload_chrome_extensions(options, self.data_dir / "assets" / "crx")
if (chromium_driver_path := Path("/usr/bin/chromedriver")).exists():
chrome_service = ChromeDriverService(str(chromium_driver_path))
@@ -361,7 +382,8 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
text,
topics_of_interest=topics_of_interest,
llm_provider=self.llm_provider,
config=self.legacy_config,
model_name=self.config.model_name,
spacy_model=self.config.browse_spacy_language_model,
)
return "\n".join(f"* {i}" for i in information)
else:
@@ -369,6 +391,7 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
text,
question=question,
llm_provider=self.llm_provider,
config=self.legacy_config,
model_name=self.config.model_name,
spacy_model=self.config.browse_spacy_language_model,
)
return result

View File

@@ -3,12 +3,10 @@ This module contains configuration models and helpers for AutoGPT Forge.
"""
from .ai_directives import AIDirectives
from .ai_profile import AIProfile
from .config import Config, ConfigBuilder, assert_config_has_required_llm_api_keys
from .base import BaseConfig
__all__ = [
"assert_config_has_required_llm_api_keys",
"AIProfile",
"AIDirectives",
"Config",
"ConfigBuilder",
"BaseConfig",
]

View File

@@ -0,0 +1,16 @@
from forge.file_storage import FileStorageBackendName
from forge.models.config import SystemSettings, UserConfigurable
from forge.speech.say import TTSConfig
class BaseConfig(SystemSettings):
    """Forge-level base configuration.

    Application-specific configs (e.g. autogpt's `Config`) extend this class
    with their own settings.
    """

    name: str = "Base configuration"
    description: str = "Default configuration for forge agent."

    # TTS configuration
    tts_config: TTSConfig = TTSConfig()

    # File storage
    file_storage_backend: FileStorageBackendName = UserConfigurable(
        default=FileStorageBackendName.LOCAL, from_env="FILE_STORAGE_BACKEND"
    )

View File

@@ -1,302 +0,0 @@
"""Configuration class to store the state of bools for different scripts access."""
from __future__ import annotations
import logging
import os
import re
from pathlib import Path
from typing import Any, Optional, Union
from pydantic import SecretStr, validator
import forge
from forge.file_storage import FileStorageBackendName
from forge.llm.providers import CHAT_MODELS, ModelName
from forge.llm.providers.openai import OpenAICredentials, OpenAIModelName
from forge.logging.config import LoggingConfig
from forge.models.config import Configurable, SystemSettings, UserConfigurable
from forge.speech.say import TTSConfig
logger = logging.getLogger(__name__)
PROJECT_ROOT = Path(forge.__file__).parent.parent
AZURE_CONFIG_FILE = Path("azure.yaml")
GPT_4_MODEL = OpenAIModelName.GPT4
GPT_3_MODEL = OpenAIModelName.GPT3
class Config(SystemSettings, arbitrary_types_allowed=True):
    """Global application configuration.

    Values are populated from defaults and environment variables (via
    `UserConfigurable(from_env=...)`), typically through
    `ConfigBuilder.build_config_from_env()`.
    """

    name: str = "Auto-GPT configuration"
    description: str = "Default configuration for the Auto-GPT application."

    ########################
    # Application Settings #
    ########################
    project_root: Path = PROJECT_ROOT
    app_data_dir: Path = project_root / "data"
    skip_news: bool = False
    skip_reprompt: bool = False
    authorise_key: str = UserConfigurable(default="y", from_env="AUTHORISE_COMMAND_KEY")
    exit_key: str = UserConfigurable(default="n", from_env="EXIT_KEY")
    noninteractive_mode: bool = False

    # TTS configuration
    logging: LoggingConfig = LoggingConfig()
    tts_config: TTSConfig = TTSConfig()

    # File storage
    file_storage_backend: FileStorageBackendName = UserConfigurable(
        default=FileStorageBackendName.LOCAL, from_env="FILE_STORAGE_BACKEND"
    )

    ##########################
    # Agent Control Settings #
    ##########################
    # Model configuration
    fast_llm: ModelName = UserConfigurable(
        default=OpenAIModelName.GPT3,
        from_env="FAST_LLM",
    )
    smart_llm: ModelName = UserConfigurable(
        default=OpenAIModelName.GPT4_TURBO,
        from_env="SMART_LLM",
    )
    temperature: float = UserConfigurable(default=0, from_env="TEMPERATURE")
    openai_functions: bool = UserConfigurable(
        default=False, from_env=lambda: os.getenv("OPENAI_FUNCTIONS", "False") == "True"
    )
    embedding_model: str = UserConfigurable(
        default="text-embedding-3-small", from_env="EMBEDDING_MODEL"
    )
    browse_spacy_language_model: str = UserConfigurable(
        default="en_core_web_sm", from_env="BROWSE_SPACY_LANGUAGE_MODEL"
    )

    # Run loop configuration
    continuous_mode: bool = False
    continuous_limit: int = 0

    ############
    # Commands #
    ############
    # General
    disabled_commands: list[str] = UserConfigurable(
        default_factory=list,
        from_env=lambda: _safe_split(os.getenv("DISABLED_COMMANDS")),
    )

    # File ops
    restrict_to_workspace: bool = UserConfigurable(
        default=True,
        from_env=lambda: os.getenv("RESTRICT_TO_WORKSPACE", "True") == "True",
    )
    allow_downloads: bool = False

    # Shell commands
    shell_command_control: str = UserConfigurable(
        default="denylist", from_env="SHELL_COMMAND_CONTROL"
    )
    execute_local_commands: bool = UserConfigurable(
        default=False,
        from_env=lambda: os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True",
    )
    shell_denylist: list[str] = UserConfigurable(
        default_factory=lambda: ["sudo", "su"],
        # SHELL_DENYLIST takes precedence; DENY_COMMANDS kept as legacy fallback
        from_env=lambda: _safe_split(
            os.getenv("SHELL_DENYLIST", os.getenv("DENY_COMMANDS"))
        ),
    )
    shell_allowlist: list[str] = UserConfigurable(
        default_factory=list,
        # SHELL_ALLOWLIST takes precedence; ALLOW_COMMANDS kept as legacy fallback
        from_env=lambda: _safe_split(
            os.getenv("SHELL_ALLOWLIST", os.getenv("ALLOW_COMMANDS"))
        ),
    )

    # Text to image
    image_provider: Optional[str] = UserConfigurable(from_env="IMAGE_PROVIDER")
    huggingface_image_model: str = UserConfigurable(
        default="CompVis/stable-diffusion-v1-4", from_env="HUGGINGFACE_IMAGE_MODEL"
    )
    sd_webui_url: Optional[str] = UserConfigurable(
        default="http://localhost:7860", from_env="SD_WEBUI_URL"
    )
    image_size: int = UserConfigurable(default=256, from_env="IMAGE_SIZE")

    # Audio to text
    audio_to_text_provider: str = UserConfigurable(
        default="huggingface", from_env="AUDIO_TO_TEXT_PROVIDER"
    )
    huggingface_audio_to_text_model: Optional[str] = UserConfigurable(
        from_env="HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
    )

    # Web browsing
    selenium_web_browser: str = UserConfigurable("chrome", from_env="USE_WEB_BROWSER")
    selenium_headless: bool = UserConfigurable(
        default=True, from_env=lambda: os.getenv("HEADLESS_BROWSER", "True") == "True"
    )
    user_agent: str = UserConfigurable(
        default="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",  # noqa: E501
        from_env="USER_AGENT",
    )

    ###############
    # Credentials #
    ###############
    # OpenAI
    openai_credentials: Optional[OpenAICredentials] = None
    azure_config_file: Optional[Path] = UserConfigurable(
        default=AZURE_CONFIG_FILE, from_env="AZURE_CONFIG_FILE"
    )

    # Github
    github_api_key: Optional[str] = UserConfigurable(from_env="GITHUB_API_KEY")
    github_username: Optional[str] = UserConfigurable(from_env="GITHUB_USERNAME")

    # Google
    google_api_key: Optional[str] = UserConfigurable(from_env="GOOGLE_API_KEY")
    google_custom_search_engine_id: Optional[str] = UserConfigurable(
        from_env="GOOGLE_CUSTOM_SEARCH_ENGINE_ID",
    )

    # Huggingface
    huggingface_api_token: Optional[str] = UserConfigurable(
        from_env="HUGGINGFACE_API_TOKEN"
    )

    # Stable Diffusion
    sd_webui_auth: Optional[str] = UserConfigurable(from_env="SD_WEBUI_AUTH")

    @validator("openai_functions")
    def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
        # OPENAI_FUNCTIONS only makes sense with a model that supports tool
        # calling; fail fast at config load instead of at the first LLM call.
        if v:
            smart_llm = values["smart_llm"]
            assert CHAT_MODELS[smart_llm].has_function_call_api, (
                f"Model {smart_llm} does not support tool calling. "
                "Please disable OPENAI_FUNCTIONS or choose a suitable model."
            )
        return v
class ConfigBuilder(Configurable[Config]):
    """Builds a `Config` from defaults and environment variables."""

    default_settings = Config()

    @classmethod
    def build_config_from_env(cls, project_root: Path = PROJECT_ROOT) -> Config:
        """Initialize the Config class.

        Args:
            project_root: base directory used to absolutize relative config paths.

        Returns:
            The fully initialized `Config` instance.
        """
        config = cls.build_agent_configuration()
        config.project_root = project_root

        # Make relative paths absolute
        for k in {
            "azure_config_file",  # TODO: move from project root
        }:
            # Guard against unset Optional[Path] fields:
            # `project_root / None` would raise a TypeError.
            if (path := getattr(config, k)) is not None:
                setattr(config, k, project_root / path)

        if (
            config.openai_credentials
            and config.openai_credentials.api_type == SecretStr("azure")
            and (config_file := config.azure_config_file)
        ):
            config.openai_credentials.load_azure_config(config_file)

        return config
async def assert_config_has_required_llm_api_keys(config: Config) -> None:
    """
    Check if API keys (if required) are set for the configured SMART_LLM and FAST_LLM.

    Raises:
        ValueError: if credentials for a required provider are missing or invalid.
    """
    from pydantic import ValidationError

    from forge.llm.providers.anthropic import AnthropicModelName
    from forge.llm.providers.groq import GroqModelName

    if set((config.smart_llm, config.fast_llm)).intersection(AnthropicModelName):
        from forge.llm.providers.anthropic import AnthropicCredentials

        try:
            credentials = AnthropicCredentials.from_env()
        except ValidationError as e:
            if "api_key" in str(e):
                logger.error(
                    "Set your Anthropic API key in .env or as an environment variable"
                )
                logger.info(
                    "For further instructions: "
                    "https://docs.agpt.co/autogpt/setup/#anthropic"
                )
            raise ValueError("Anthropic is unavailable: can't load credentials") from e

        key_pattern = r"^sk-ant-api03-[\w\-]{95}"
        # If key is set, but it looks invalid
        if not re.search(key_pattern, credentials.api_key.get_secret_value()):
            logger.warning(
                "Possibly invalid Anthropic API key! "
                f"Configured Anthropic API key does not match pattern '{key_pattern}'. "
                "If this is a valid key, please report this warning to the maintainers."
            )

    if set((config.smart_llm, config.fast_llm)).intersection(GroqModelName):
        from groq import AuthenticationError

        from forge.llm.providers.groq import GroqProvider

        try:
            groq = GroqProvider()
            await groq.get_available_models()
        except ValidationError as e:
            if "api_key" not in str(e):
                raise
            logger.error("Set your Groq API key in .env or as an environment variable")
            logger.info(
                "For further instructions: https://docs.agpt.co/autogpt/setup/#groq"
            )
            # Explicitly chain the cause, consistent with the other branches
            raise ValueError("Groq is unavailable: can't load credentials") from e
        except AuthenticationError as e:
            logger.error("The Groq API key is invalid!")
            logger.info(
                "For instructions to get and set a new API key: "
                "https://docs.agpt.co/autogpt/setup/#groq"
            )
            raise ValueError("Groq is unavailable: invalid API key") from e

    if set((config.smart_llm, config.fast_llm)).intersection(OpenAIModelName):
        from openai import AuthenticationError

        from forge.llm.providers.openai import OpenAIProvider

        try:
            openai = OpenAIProvider()
            await openai.get_available_models()
        except ValidationError as e:
            if "api_key" not in str(e):
                raise
            logger.error(
                "Set your OpenAI API key in .env or as an environment variable"
            )
            logger.info(
                "For further instructions: https://docs.agpt.co/autogpt/setup/#openai"
            )
            # Explicitly chain the cause, consistent with the other branches
            raise ValueError("OpenAI is unavailable: can't load credentials") from e
        except AuthenticationError as e:
            logger.error("The OpenAI API key is invalid!")
            logger.info(
                "For instructions to get and set a new API key: "
                "https://docs.agpt.co/autogpt/setup/#openai"
            )
            raise ValueError("OpenAI is unavailable: invalid API key") from e
def _safe_split(s: Union[str, None], sep: str = ",") -> list[str]:
"""Split a string by a separator. Return an empty list if the string is None."""
if s is None:
return []
return s.split(sep)

View File

@@ -3,16 +3,14 @@ from __future__ import annotations
import logging
import math
from typing import TYPE_CHECKING, Iterator, Optional, TypeVar
from typing import Iterator, Optional, TypeVar
import spacy
if TYPE_CHECKING:
from forge.config.config import Config
from forge.json.parsing import extract_list_from_json
from forge.llm.prompting import ChatPrompt
from forge.llm.providers import ChatMessage, ModelTokenizer, MultiProvider
from forge.llm.providers.multi import ModelName
logger = logging.getLogger(__name__)
@@ -57,7 +55,8 @@ def chunk_content(
async def summarize_text(
text: str,
llm_provider: MultiProvider,
config: Config,
model_name: ModelName,
spacy_model: str = "en_core_web_sm",
question: Optional[str] = None,
instruction: Optional[str] = None,
) -> tuple[str, list[tuple[str, str]]]:
@@ -82,7 +81,8 @@ async def summarize_text(
text=text,
instruction=instruction,
llm_provider=llm_provider,
config=config,
model_name=model_name,
spacy_model=spacy_model,
)
@@ -90,7 +90,8 @@ async def extract_information(
source_text: str,
topics_of_interest: list[str],
llm_provider: MultiProvider,
config: Config,
model_name: ModelName,
spacy_model: str = "en_core_web_sm",
) -> list[str]:
fmt_topics_list = "\n".join(f"* {topic}." for topic in topics_of_interest)
instruction = (
@@ -106,7 +107,8 @@ async def extract_information(
instruction=instruction,
output_type=list[str],
llm_provider=llm_provider,
config=config,
model_name=model_name,
spacy_model=spacy_model,
)
@@ -114,7 +116,8 @@ async def _process_text(
text: str,
instruction: str,
llm_provider: MultiProvider,
config: Config,
model_name: ModelName,
spacy_model: str = "en_core_web_sm",
output_type: type[str | list[str]] = str,
) -> tuple[str, list[tuple[str, str]]] | list[str]:
"""Process text using the OpenAI API for summarization or information extraction
@@ -123,7 +126,8 @@ async def _process_text(
text (str): The text to process.
instruction (str): Additional instruction for processing.
llm_provider: LLM provider to use.
config (Config): The global application config.
model_name: The name of the llm model to use.
spacy_model: The spaCy model to use for sentence splitting.
output_type: `str` for summaries or `list[str]` for piece-wise info extraction.
Returns:
@@ -133,13 +137,11 @@ async def _process_text(
if not text.strip():
raise ValueError("No content")
model = config.fast_llm
text_tlength = llm_provider.count_tokens(text, model)
text_tlength = llm_provider.count_tokens(text, model_name)
logger.debug(f"Text length: {text_tlength} tokens")
max_result_tokens = 500
max_chunk_length = llm_provider.get_token_limit(model) - max_result_tokens - 50
max_chunk_length = llm_provider.get_token_limit(model_name) - max_result_tokens - 50
logger.debug(f"Max chunk length: {max_chunk_length} tokens")
if text_tlength < max_chunk_length:
@@ -157,7 +159,7 @@ async def _process_text(
response = await llm_provider.create_chat_completion(
model_prompt=prompt.messages,
model_name=model,
model_name=model_name,
temperature=0.5,
max_output_tokens=max_result_tokens,
completion_parser=lambda s: (
@@ -182,9 +184,9 @@ async def _process_text(
chunks = list(
split_text(
text,
config=config,
max_chunk_length=max_chunk_length,
tokenizer=llm_provider.get_tokenizer(model),
tokenizer=llm_provider.get_tokenizer(model_name),
spacy_model=spacy_model,
)
)
@@ -196,7 +198,8 @@ async def _process_text(
instruction=instruction,
output_type=output_type,
llm_provider=llm_provider,
config=config,
model_name=model_name,
spacy_model=spacy_model,
)
processed_results.extend(
chunk_result if output_type == list[str] else [chunk_result]
@@ -212,7 +215,8 @@ async def _process_text(
"Combine these partial summaries into one."
),
llm_provider=llm_provider,
config=config,
model_name=model_name,
spacy_model=spacy_model,
)
return summary.strip(), [
(processed_results[i], chunks[i][0]) for i in range(0, len(chunks))
@@ -221,9 +225,9 @@ async def _process_text(
def split_text(
text: str,
config: Config,
max_chunk_length: int,
tokenizer: ModelTokenizer,
spacy_model: str = "en_core_web_sm",
with_overlap: bool = True,
) -> Iterator[tuple[str, int]]:
"""
@@ -231,7 +235,7 @@ def split_text(
Args:
text (str): The text to split.
config (Config): Config object containing the Spacy model setting.
spacy_model (str): The spaCy model to use for sentence splitting.
max_chunk_length (int, optional): The maximum length of a chunk.
tokenizer (ModelTokenizer): Tokenizer to use for determining chunk length.
with_overlap (bool, optional): Whether to allow overlap between chunks.
@@ -251,7 +255,7 @@ def split_text(
n_chunks = math.ceil(text_length / max_chunk_length)
target_chunk_length = math.ceil(text_length / n_chunks)
nlp: spacy.language.Language = spacy.load(config.browse_spacy_language_model)
nlp: spacy.language.Language = spacy.load(spacy_model)
nlp.add_pipe("sentencizer")
doc = nlp(text)
sentences = [sentence.text.strip() for sentence in doc.sents]

View File

@@ -16,6 +16,7 @@ def UserConfigurable(
default_factory: Optional[Callable[[], T]] = None,
from_env: Optional[str | Callable[[], T | None]] = None,
description: str = "",
exclude: bool = False,
**kwargs,
) -> T:
# TODO: use this to auto-generate docs for the application configuration
@@ -25,6 +26,7 @@ def UserConfigurable(
default_factory=default_factory,
from_env=from_env,
description=description,
exclude=exclude,
**kwargs,
user_configurable=True,
)