mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(forge): Component-specific configuration (#7170)
Remove many env vars and use component-level configuration that can be loaded from a file instead. ### Changed - `BaseAgent` provides `dump_component_configs` and `load_component_configs`, which save and load all component configuration as a JSON `str`. Deserialized components/values overwrite existing values, so not all values need to be present in the serialized config. - Decoupled `forge/content_processing/text.py` from `Config` - Kept `execute_local_commands` in `Config` because it's needed to know whether OS info should be included in the prompt - Updated docs to reflect changes - Renamed `Config` to `AppConfig` ### Added - Added `ConfigurableComponent` class for components and the following configs: - `ActionHistoryConfiguration` - `CodeExecutorConfiguration` - `FileManagerConfiguration` - the file manager now allows multiple agents to use the same workspace - `GitOperationsConfiguration` - `ImageGeneratorConfiguration` - `WebSearchConfiguration` - `WebSeleniumConfiguration` - `BaseConfig` in `forge` and moved `Config` (now inherits from `BaseConfig`) back to `autogpt` - Required `config_class` attribute for the `ConfigurableComponent` class, which must be set to the configuration class of the component - `--component-config-file` CLI option and `COMPONENT_CONFIG_FILE` env var and field in `Config`. This option allows loading configuration from a specific file; the CLI option takes precedence over the env var. - Added comments to config models ### Removed - Unused `change_agent_id` method from `FileManagerComponent` - Unused `allow_downloads` from `Config` and CLI options (it should be in the web component config if needed) - CLI option `--browser-name` (the option is now inside `WebSeleniumConfiguration`) - Unused `workspace_directory` from CLI options - No longer needed variables from `Config` and docs - Unused fields from `Config`: `image_size`, `audio_to_text_provider`, `huggingface_audio_to_text_model` - Removed `files` and `workspace` class attributes from `FileManagerComponent`
This commit is contained in:
committed by
GitHub
parent
02dc198a9f
commit
c19ab2b24f
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import (
|
||||
@@ -18,12 +19,13 @@ from typing import (
|
||||
)
|
||||
|
||||
from colorama import Fore
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field, parse_raw_as, validator
|
||||
|
||||
from forge.agent import protocols
|
||||
from forge.agent.components import (
|
||||
AgentComponent,
|
||||
ComponentEndpointError,
|
||||
ConfigurableComponent,
|
||||
EndpointPipelineError,
|
||||
)
|
||||
from forge.config.ai_directives import AIDirectives
|
||||
@@ -45,6 +47,11 @@ DEFAULT_TRIGGERING_PROMPT = (
|
||||
)
|
||||
|
||||
|
||||
# HACK: This is a workaround wrapper to de/serialize component configs until pydantic v2
class ModelContainer(BaseModel):
    """Wraps a name -> model mapping so it can be serialized with one `.json()` call."""

    models: dict[str, BaseModel]
|
||||
|
||||
|
||||
class BaseAgentConfiguration(SystemConfiguration):
|
||||
allow_fs_access: bool = UserConfigurable(default=False)
|
||||
|
||||
@@ -82,9 +89,6 @@ class BaseAgentConfiguration(SystemConfiguration):
|
||||
defaults to 75% of `llm.max_tokens`.
|
||||
"""
|
||||
|
||||
summary_max_tlength: Optional[int] = None
|
||||
# TODO: move to ActionHistoryConfiguration
|
||||
|
||||
@validator("use_functions_api")
|
||||
def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
|
||||
if v:
|
||||
@@ -272,6 +276,30 @@ class BaseAgent(Generic[AnyProposal], metaclass=AgentMeta):
|
||||
raise e
|
||||
return method_result
|
||||
|
||||
def dump_component_configs(self) -> str:
    """Serialize the configs of all configurable components to a JSON string.

    Returns:
        str: JSON object (indented by 4 spaces) that maps each config class
            name to that config's serialized fields.
    """
    configs_by_type_name = {
        type(component.config).__name__: component.config
        for component in self.components
        if isinstance(component, ConfigurableComponent)
    }
    # Let pydantic serialize the models, then re-parse the result so that only
    # the inner mapping (without the `models` wrapper key) is dumped.
    serialized = ModelContainer(models=configs_by_type_name).json()
    parsed = parse_raw_as(dict[str, dict[str, Any]], serialized)
    return json.dumps(parsed["models"], indent=4)
|
||||
|
||||
def load_component_configs(self, serialized_configs: str):
    """Update component configs from a JSON string.

    The JSON object is expected to map config class names to field values, in
    the format produced by `dump_component_configs`. Entries that match no
    component of this agent are ignored; fields missing from an entry keep
    their current value.

    Args:
        serialized_configs (str): JSON object keyed by config class name.
    """
    configs_dict = parse_raw_as(dict[str, dict[str, Any]], serialized_configs)

    for component in self.components:
        if not isinstance(component, ConfigurableComponent):
            continue
        current_config = component.config
        config_type = type(current_config)
        updated_data = configs_dict.get(config_type.__name__)
        if updated_data is None:
            continue
        # Read the current values straight off the model instead of `.dict()`:
        # `.dict()` drops fields declared with `exclude=True` (e.g. API keys),
        # which would silently reset them to their defaults here.
        current_data = {
            name: getattr(current_config, name) for name in config_type.__fields__
        }
        # Values from the serialized config win over the current ones
        component.config = config_type(**{**current_data, **updated_data})
|
||||
|
||||
def _collect_components(self):
|
||||
components = [
|
||||
getattr(self, attr)
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import Callable, TypeVar
|
||||
from typing import Callable, ClassVar, Generic, Optional, TypeVar
|
||||
|
||||
T = TypeVar("T", bound="AgentComponent")
|
||||
from pydantic import BaseModel
|
||||
|
||||
from forge.models.config import _update_user_config_from_env, deep_update
|
||||
|
||||
AC = TypeVar("AC", bound="AgentComponent")
|
||||
BM = TypeVar("BM", bound=BaseModel)
|
||||
|
||||
|
||||
class AgentComponent(ABC):
|
||||
@@ -24,7 +29,7 @@ class AgentComponent(ABC):
|
||||
"""Return the reason this component is disabled."""
|
||||
return self._disabled_reason
|
||||
|
||||
def run_after(self: T, *components: type[AgentComponent] | AgentComponent) -> T:
|
||||
def run_after(self: AC, *components: type[AgentComponent] | AgentComponent) -> AC:
|
||||
"""Set the components that this component should run after."""
|
||||
for component in components:
|
||||
t = component if isinstance(component, type) else type(component)
|
||||
@@ -33,6 +38,39 @@ class AgentComponent(ABC):
|
||||
return self
|
||||
|
||||
|
||||
class ConfigurableComponent(ABC, Generic[BM]):
    """A component that can be configured with a Pydantic model.

    Subclasses must set the `config_class` class attribute to their
    configuration model; this is enforced when the subclass is created.
    """

    config_class: ClassVar[type[BM]]  # type: ignore

    def __init__(self, configuration: Optional[BM]):
        """Store `configuration` if one is given."""
        self._config: Optional[BM] = None
        if configuration is None:
            return
        self.config = configuration

    def __init_subclass__(cls, **kwargs):
        """Reject subclasses that do not define `config_class`."""
        super().__init_subclass__(**kwargs)
        if getattr(cls, "config_class", None) is not None:
            return
        raise NotImplementedError(
            f"ConfigurableComponent subclass {cls.__name__} "
            "must define config_class class attribute."
        )

    @property
    def config(self) -> BM:
        """The current configuration.

        If none has been set yet, a default `config_class()` instance is
        created and assigned through the setter on first access.
        """
        # `_config` may be missing entirely, hence `getattr` with a default
        if getattr(self, "_config", None) is None:
            self.config = self.config_class()
        return self._config  # type: ignore

    @config.setter
    def config(self, config: BM):
        if getattr(self, "_config", None) is None:
            # First assignment: load configuration from environment variables
            env_overrides = _update_user_config_from_env(config)
            merged = deep_update(config.dict(), env_overrides)
            config = self.config_class(**merged)
        self._config = config
|
||||
|
||||
|
||||
class ComponentEndpointError(Exception):
|
||||
"""Error of a single protocol method on a component."""
|
||||
|
||||
|
||||
@@ -30,6 +30,48 @@ class MyAgent(BaseAgent):
|
||||
self.some_component = SomeComponent(self.hello_component)
|
||||
```
|
||||
|
||||
## Component configuration
|
||||
|
||||
Each component can have its own configuration defined using a regular pydantic `BaseModel`.
|
||||
To ensure the configuration is loaded from the file correctly, the component must inherit from `ConfigurableComponent[T]` where `T` is the configuration model it uses.
|
||||
`ConfigurableComponent` provides a `config` attribute that holds the configuration instance.
|
||||
It's possible to either set the `config` attribute directly or pass the configuration instance to the component's constructor.
|
||||
Extra configuration (i.e. for components that are not part of the agent) can be passed and will be silently ignored. Extra config won't be applied even if the component is added later.
|
||||
To see the configuration of built-in components visit [Built-in Components](./built-in-components.md).
|
||||
|
||||
```py
|
||||
from pydantic import BaseModel
|
||||
from forge.agent.components import AgentComponent, ConfigurableComponent
|
||||
|
||||
class MyConfig(BaseModel):
|
||||
some_value: str
|
||||
|
||||
class MyComponent(AgentComponent, ConfigurableComponent[MyConfig]):
    # Required: `ConfigurableComponent` raises `NotImplementedError` at class
    # creation if a subclass doesn't define `config_class`
    config_class = MyConfig

    def __init__(self, config: MyConfig):
        super().__init__(config)
        # This has the same effect as above:
        # self.config = config

    def get_some_value(self) -> str:
        # Access the configuration like a regular model
        return self.config.some_value
|
||||
```
|
||||
|
||||
### Sensitive information
|
||||
|
||||
While it's possible to pass sensitive data directly in code to the configuration, it's recommended to use a `UserConfigurable(from_env="ENV_VAR_NAME", exclude=True)` field for sensitive data like API keys.
|
||||
The data will be loaded from the environment variable, but keep in mind that a value passed in code takes precedence.
|
||||
All fields, even excluded ones (`exclude=True`) will be loaded when the configuration is loaded from the file.
|
||||
Exclusion allows you to skip them during *serialization*; a non-excluded `SecretStr` will be serialized literally as a `"**********"` string.
|
||||
|
||||
```py
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from forge.models.config import UserConfigurable
|
||||
|
||||
class SensitiveConfig(BaseModel):
|
||||
api_key: SecretStr = UserConfigurable(from_env="API_KEY", exclude=True)
|
||||
```
|
||||
|
||||
## Ordering components
|
||||
|
||||
The execution order of components is important because some may depend on the results of the previous ones.
|
||||
@@ -72,6 +114,7 @@ class MyAgent(Agent):
|
||||
## Disabling components
|
||||
|
||||
You can control which components are enabled by setting their `_enabled` attribute.
|
||||
Components are *enabled* by default.
|
||||
Either provide a `bool` value or a `Callable[[], bool]`, which will be checked each time
|
||||
the component is about to be executed. This way you can dynamically enable or disable
|
||||
components based on some conditions.
|
||||
|
||||
@@ -1,38 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Callable, Iterator, Optional
|
||||
from typing import Callable, Iterator, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from forge.agent.components import ConfigurableComponent
|
||||
from forge.agent.protocols import AfterExecute, AfterParse, MessageProvider
|
||||
from forge.llm.prompting.utils import indent
|
||||
from forge.llm.providers import ChatMessage, MultiProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from forge.config.config import Config
|
||||
from forge.llm.providers.multi import ModelName
|
||||
from forge.llm.providers.openai import OpenAIModelName
|
||||
|
||||
from .model import ActionResult, AnyProposal, Episode, EpisodicActionHistory
|
||||
|
||||
|
||||
class ActionHistoryComponent(MessageProvider, AfterParse[AnyProposal], AfterExecute):
|
||||
class ActionHistoryConfiguration(BaseModel):
    """Configuration model used by `ActionHistoryComponent`."""

    model_name: ModelName = OpenAIModelName.GPT3
    """Name of the llm model used to compress the history"""
    max_tokens: int = 1024
    """Maximum number of tokens to use up with generated history messages"""
    spacy_language_model: str = "en_core_web_sm"
    """Language model used for summary chunking using spacy"""
|
||||
|
||||
|
||||
class ActionHistoryComponent(
|
||||
MessageProvider,
|
||||
AfterParse[AnyProposal],
|
||||
AfterExecute,
|
||||
ConfigurableComponent[ActionHistoryConfiguration],
|
||||
):
|
||||
"""Keeps track of the event history and provides a summary of the steps."""
|
||||
|
||||
config_class = ActionHistoryConfiguration
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_history: EpisodicActionHistory[AnyProposal],
|
||||
max_tokens: int,
|
||||
count_tokens: Callable[[str], int],
|
||||
legacy_config: Config,
|
||||
llm_provider: MultiProvider,
|
||||
config: Optional[ActionHistoryConfiguration] = None,
|
||||
) -> None:
|
||||
ConfigurableComponent.__init__(self, config)
|
||||
self.event_history = event_history
|
||||
self.max_tokens = max_tokens
|
||||
self.count_tokens = count_tokens
|
||||
self.legacy_config = legacy_config
|
||||
self.llm_provider = llm_provider
|
||||
|
||||
def get_messages(self) -> Iterator[ChatMessage]:
|
||||
if progress := self._compile_progress(
|
||||
self.event_history.episodes,
|
||||
self.max_tokens,
|
||||
self.config.max_tokens,
|
||||
self.count_tokens,
|
||||
):
|
||||
yield ChatMessage.system(f"## Progress on your Task so far\n\n{progress}")
|
||||
@@ -43,7 +59,7 @@ class ActionHistoryComponent(MessageProvider, AfterParse[AnyProposal], AfterExec
|
||||
async def after_execute(self, result: ActionResult) -> None:
|
||||
self.event_history.register_result(result)
|
||||
await self.event_history.handle_compression(
|
||||
self.llm_provider, self.legacy_config
|
||||
self.llm_provider, self.config.model_name, self.config.spacy_language_model
|
||||
)
|
||||
|
||||
def _compile_progress(
|
||||
|
||||
@@ -8,11 +8,11 @@ from pydantic.generics import GenericModel
|
||||
|
||||
from forge.content_processing.text import summarize_text
|
||||
from forge.llm.prompting.utils import format_numbered_list, indent
|
||||
from forge.llm.providers.multi import ModelName
|
||||
from forge.models.action import ActionResult, AnyProposal
|
||||
from forge.models.utils import ModelWithSummary
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from forge.config.config import Config
|
||||
from forge.llm.providers import MultiProvider
|
||||
|
||||
|
||||
@@ -108,7 +108,10 @@ class EpisodicActionHistory(GenericModel, Generic[AnyProposal]):
|
||||
self.cursor = len(self.episodes)
|
||||
|
||||
async def handle_compression(
|
||||
self, llm_provider: MultiProvider, app_config: Config
|
||||
self,
|
||||
llm_provider: MultiProvider,
|
||||
model_name: ModelName,
|
||||
spacy_model: str,
|
||||
) -> None:
|
||||
"""Compresses each episode in the action history using an LLM.
|
||||
|
||||
@@ -131,7 +134,8 @@ class EpisodicActionHistory(GenericModel, Generic[AnyProposal]):
|
||||
episode.format(),
|
||||
instruction=compress_instruction,
|
||||
llm_provider=llm_provider,
|
||||
config=app_config,
|
||||
model_name=model_name,
|
||||
spacy_model=spacy_model,
|
||||
)
|
||||
for episode in episodes_to_summarize
|
||||
]
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
from .code_executor import (
|
||||
ALLOWLIST_CONTROL,
|
||||
DENYLIST_CONTROL,
|
||||
CodeExecutionError,
|
||||
CodeExecutorComponent,
|
||||
)
|
||||
from .code_executor import CodeExecutionError, CodeExecutorComponent
|
||||
|
||||
__all__ = [
|
||||
"ALLOWLIST_CONTROL",
|
||||
|
||||
@@ -5,16 +5,16 @@ import shlex
|
||||
import string
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
from typing import Iterator, Literal, Optional
|
||||
|
||||
import docker
|
||||
from docker.errors import DockerException, ImageNotFound, NotFound
|
||||
from docker.models.containers import Container as DockerContainer
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from forge.agent import BaseAgentSettings
|
||||
from forge.agent.components import ConfigurableComponent
|
||||
from forge.agent.protocols import CommandProvider
|
||||
from forge.command import Command, command
|
||||
from forge.config.config import Config
|
||||
from forge.file_storage import FileStorage
|
||||
from forge.models.json_schema import JSONSchema
|
||||
from forge.utils.exceptions import (
|
||||
@@ -25,9 +25,6 @@ from forge.utils.exceptions import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALLOWLIST_CONTROL = "allowlist"
|
||||
DENYLIST_CONTROL = "denylist"
|
||||
|
||||
|
||||
def we_are_running_in_a_docker_container() -> bool:
|
||||
"""Check if we are running in a Docker container
|
||||
@@ -56,15 +53,45 @@ class CodeExecutionError(CommandExecutionError):
|
||||
"""The operation (an attempt to run arbitrary code) returned an error"""
|
||||
|
||||
|
||||
class CodeExecutorComponent(CommandProvider):
|
||||
class CodeExecutorConfiguration(BaseModel):
    """Configuration model used by `CodeExecutorComponent`."""

    execute_local_commands: bool = False
    """Enable shell command execution"""
    shell_command_control: Literal["allowlist", "denylist"] = "allowlist"
    """Controls which list is used"""
    shell_allowlist: list[str] = Field(default_factory=list)
    """List of allowed shell commands"""
    shell_denylist: list[str] = Field(default_factory=list)
    """List of prohibited shell commands"""
    docker_container_name: str = "agent_sandbox"
    """Name of the Docker container used for code execution"""
|
||||
|
||||
|
||||
class CodeExecutorComponent(
|
||||
CommandProvider, ConfigurableComponent[CodeExecutorConfiguration]
|
||||
):
|
||||
"""Provides commands to execute Python code and shell commands."""
|
||||
|
||||
config_class = CodeExecutorConfiguration
|
||||
|
||||
def __init__(
|
||||
self, workspace: FileStorage, state: BaseAgentSettings, config: Config
|
||||
self,
|
||||
workspace: FileStorage,
|
||||
config: Optional[CodeExecutorConfiguration] = None,
|
||||
):
|
||||
ConfigurableComponent.__init__(self, config)
|
||||
self.workspace = workspace
|
||||
self.state = state
|
||||
self.legacy_config = config
|
||||
|
||||
# Change container name if it's empty or default to prevent different agents
|
||||
# from using the same container
|
||||
default_container_name = self.config.__fields__["docker_container_name"].default
|
||||
if (
|
||||
not self.config.docker_container_name
|
||||
or self.config.docker_container_name == default_container_name
|
||||
):
|
||||
random_suffix = "".join(random.choices(string.ascii_lowercase, k=8))
|
||||
self.config.docker_container_name = (
|
||||
f"{default_container_name}_{random_suffix}"
|
||||
)
|
||||
|
||||
if not we_are_running_in_a_docker_container() and not is_docker_available():
|
||||
logger.info(
|
||||
@@ -72,7 +99,7 @@ class CodeExecutorComponent(CommandProvider):
|
||||
"The code execution commands will not be available."
|
||||
)
|
||||
|
||||
if not self.legacy_config.execute_local_commands:
|
||||
if not self.config.execute_local_commands:
|
||||
logger.info(
|
||||
"Local shell commands are disabled. To enable them,"
|
||||
" set EXECUTE_LOCAL_COMMANDS to 'True' in your config file."
|
||||
@@ -83,7 +110,7 @@ class CodeExecutorComponent(CommandProvider):
|
||||
yield self.execute_python_code
|
||||
yield self.execute_python_file
|
||||
|
||||
if self.legacy_config.execute_local_commands:
|
||||
if self.config.execute_local_commands:
|
||||
yield self.execute_shell
|
||||
yield self.execute_shell_popen
|
||||
|
||||
@@ -192,7 +219,7 @@ class CodeExecutorComponent(CommandProvider):
|
||||
logger.debug("App is not running in a Docker container")
|
||||
return self._run_python_code_in_docker(file_path, args)
|
||||
|
||||
def validate_command(self, command_line: str, config: Config) -> tuple[bool, bool]:
|
||||
def validate_command(self, command_line: str) -> tuple[bool, bool]:
|
||||
"""Check whether a command is allowed and whether it may be executed in a shell.
|
||||
|
||||
If shell command control is enabled, we disallow executing in a shell, because
|
||||
@@ -211,10 +238,10 @@ class CodeExecutorComponent(CommandProvider):
|
||||
|
||||
command_name = shlex.split(command_line)[0]
|
||||
|
||||
if config.shell_command_control == ALLOWLIST_CONTROL:
|
||||
return command_name in config.shell_allowlist, False
|
||||
elif config.shell_command_control == DENYLIST_CONTROL:
|
||||
return command_name not in config.shell_denylist, False
|
||||
if self.config.shell_command_control == "allowlist":
|
||||
return command_name in self.config.shell_allowlist, False
|
||||
elif self.config.shell_command_control == "denylist":
|
||||
return command_name not in self.config.shell_denylist, False
|
||||
else:
|
||||
return True, True
|
||||
|
||||
@@ -238,9 +265,7 @@ class CodeExecutorComponent(CommandProvider):
|
||||
Returns:
|
||||
str: The output of the command
|
||||
"""
|
||||
allow_execute, allow_shell = self.validate_command(
|
||||
command_line, self.legacy_config
|
||||
)
|
||||
allow_execute, allow_shell = self.validate_command(command_line)
|
||||
if not allow_execute:
|
||||
logger.info(f"Command '{command_line}' not allowed")
|
||||
raise OperationNotAllowedError("This shell command is not allowed.")
|
||||
@@ -287,9 +312,7 @@ class CodeExecutorComponent(CommandProvider):
|
||||
Returns:
|
||||
str: Description of the fact that the process started and its id
|
||||
"""
|
||||
allow_execute, allow_shell = self.validate_command(
|
||||
command_line, self.legacy_config
|
||||
)
|
||||
allow_execute, allow_shell = self.validate_command(command_line)
|
||||
if not allow_execute:
|
||||
logger.info(f"Command '{command_line}' not allowed")
|
||||
raise OperationNotAllowedError("This shell command is not allowed.")
|
||||
@@ -320,12 +343,10 @@ class CodeExecutorComponent(CommandProvider):
|
||||
"""Run a Python script in a Docker container"""
|
||||
file_path = self.workspace.get_path(filename)
|
||||
try:
|
||||
assert self.state.agent_id, "Need Agent ID to attach Docker container"
|
||||
|
||||
client = docker.from_env()
|
||||
image_name = "python:3-alpine"
|
||||
container_is_fresh = False
|
||||
container_name = f"{self.state.agent_id}_sandbox"
|
||||
container_name = self.config.docker_container_name
|
||||
with self.workspace.mount() as local_path:
|
||||
try:
|
||||
container: DockerContainer = client.containers.get(
|
||||
|
||||
@@ -3,7 +3,10 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from forge.agent import BaseAgentSettings
|
||||
from forge.agent.components import ConfigurableComponent
|
||||
from forge.agent.protocols import CommandProvider, DirectiveProvider
|
||||
from forge.command import Command, command
|
||||
from forge.file_storage.base import FileStorage
|
||||
@@ -13,67 +16,89 @@ from forge.utils.file_operations import decode_textual_file
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileManagerComponent(DirectiveProvider, CommandProvider):
|
||||
class FileManagerConfiguration(BaseModel):
    """Configuration model used by `FileManagerComponent`."""

    storage_path: str
    """Path to agent files, e.g. state"""
    workspace_path: str
    """Path to files that agent has access to"""

    class Config:
        # Prevent mutation of the configuration
        # as this wouldn't be reflected in the file storage
        allow_mutation = False
|
||||
|
||||
|
||||
class FileManagerComponent(
|
||||
DirectiveProvider, CommandProvider, ConfigurableComponent[FileManagerConfiguration]
|
||||
):
|
||||
"""
|
||||
Adds general file manager (e.g. Agent state),
|
||||
workspace manager (e.g. Agent output files) support and
|
||||
commands to perform operations on files and folders.
|
||||
"""
|
||||
|
||||
files: FileStorage
|
||||
"""Agent-related files, e.g. state, logs.
|
||||
Use `workspace` to access the agent's workspace files."""
|
||||
|
||||
workspace: FileStorage
|
||||
"""Workspace that the agent has access to, e.g. for reading/writing files.
|
||||
Use `files` to access agent-related files, e.g. state, logs."""
|
||||
config_class = FileManagerConfiguration
|
||||
|
||||
STATE_FILE = "state.json"
|
||||
"""The name of the file where the agent's state is stored."""
|
||||
|
||||
def __init__(self, state: BaseAgentSettings, file_storage: FileStorage):
|
||||
self.state = state
|
||||
def __init__(
|
||||
self,
|
||||
file_storage: FileStorage,
|
||||
agent_state: BaseAgentSettings,
|
||||
config: Optional[FileManagerConfiguration] = None,
|
||||
):
|
||||
"""Initialise the FileManagerComponent.
|
||||
`agent_state.agent_id` must be set. If `config` is omitted, the storage and workspace paths are derived from the agent ID.
|
||||
|
||||
if not state.agent_id:
|
||||
Args:
|
||||
file_storage (FileStorage): The file storage instance to use.
|
||||
agent_state (BaseAgentSettings): The agent's state.
|
||||
config (FileManagerConfiguration, optional): The configuration for
|
||||
the file manager. Defaults to None.
|
||||
"""
|
||||
if not agent_state.agent_id:
|
||||
raise ValueError("Agent must have an ID.")
|
||||
|
||||
self.files = file_storage.clone_with_subroot(f"agents/{state.agent_id}/")
|
||||
self.workspace = file_storage.clone_with_subroot(
|
||||
f"agents/{state.agent_id}/workspace"
|
||||
)
|
||||
self.agent_state = agent_state
|
||||
|
||||
if not config:
|
||||
storage_path = f"agents/{self.agent_state.agent_id}/"
|
||||
workspace_path = f"agents/{self.agent_state.agent_id}/workspace"
|
||||
ConfigurableComponent.__init__(
|
||||
self,
|
||||
FileManagerConfiguration(
|
||||
storage_path=storage_path, workspace_path=workspace_path
|
||||
),
|
||||
)
|
||||
else:
|
||||
ConfigurableComponent.__init__(self, config)
|
||||
|
||||
self.storage = file_storage.clone_with_subroot(self.config.storage_path)
|
||||
"""Agent-related files, e.g. state, logs.
|
||||
Use `workspace` to access the agent's workspace files."""
|
||||
self.workspace = file_storage.clone_with_subroot(self.config.workspace_path)
|
||||
"""Workspace that the agent has access to, e.g. for reading/writing files.
|
||||
Use `storage` to access agent-related files, e.g. state, logs."""
|
||||
self._file_storage = file_storage
|
||||
|
||||
async def save_state(self, save_as: Optional[str] = None) -> None:
|
||||
"""Save the agent's state to the state file."""
|
||||
state: BaseAgentSettings = getattr(self, "state")
|
||||
if save_as:
|
||||
temp_id = state.agent_id
|
||||
state.agent_id = save_as
|
||||
self._file_storage.make_dir(f"agents/{save_as}")
|
||||
async def save_state(self, save_as_id: Optional[str] = None) -> None:
|
||||
"""Save the agent's data and state."""
|
||||
if save_as_id:
|
||||
self._file_storage.make_dir(f"agents/{save_as_id}")
|
||||
# Save state
|
||||
await self._file_storage.write_file(
|
||||
f"agents/{save_as}/{self.STATE_FILE}", state.json()
|
||||
f"agents/{save_as_id}/{self.STATE_FILE}", self.agent_state.json()
|
||||
)
|
||||
# Copy workspace
|
||||
self._file_storage.copy(
|
||||
f"agents/{temp_id}/workspace",
|
||||
f"agents/{save_as}/workspace",
|
||||
self.config.workspace_path,
|
||||
f"agents/{save_as_id}/workspace",
|
||||
)
|
||||
state.agent_id = temp_id
|
||||
else:
|
||||
await self.files.write_file(self.files.root / self.STATE_FILE, state.json())
|
||||
|
||||
def change_agent_id(self, new_id: str):
|
||||
"""Change the agent's ID and update the file storage accordingly."""
|
||||
state: BaseAgentSettings = getattr(self, "state")
|
||||
# Rename the agent's files and workspace
|
||||
self._file_storage.rename(f"agents/{state.agent_id}", f"agents/{new_id}")
|
||||
# Update the file storage objects
|
||||
self.files = self._file_storage.clone_with_subroot(f"agents/{new_id}/")
|
||||
self.workspace = self._file_storage.clone_with_subroot(
|
||||
f"agents/{new_id}/workspace"
|
||||
)
|
||||
state.agent_id = new_id
|
||||
await self.storage.write_file(
|
||||
self.storage.root / self.STATE_FILE, self.agent_state.json()
|
||||
)
|
||||
|
||||
def get_resources(self) -> Iterator[str]:
|
||||
yield "The ability to read and write files."
|
||||
|
||||
@@ -1,23 +1,36 @@
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from git.repo import Repo
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from forge.agent.components import ConfigurableComponent
|
||||
from forge.agent.protocols import CommandProvider
|
||||
from forge.command import Command, command
|
||||
from forge.config.config import Config
|
||||
from forge.models.config import UserConfigurable
|
||||
from forge.models.json_schema import JSONSchema
|
||||
from forge.utils.exceptions import CommandExecutionError
|
||||
from forge.utils.url_validator import validate_url
|
||||
|
||||
|
||||
class GitOperationsComponent(CommandProvider):
|
||||
class GitOperationsConfiguration(BaseModel):
    """Configuration model used by `GitOperationsComponent`."""

    # GitHub user name; can be supplied via the GITHUB_USERNAME env variable
    github_username: Optional[str] = UserConfigurable(from_env="GITHUB_USERNAME")
    # GitHub API key; can be supplied via the GITHUB_API_KEY env variable.
    # Declared with `exclude=True` so it is skipped during serialization.
    github_api_key: Optional[SecretStr] = UserConfigurable(
        from_env="GITHUB_API_KEY", exclude=True
    )
|
||||
|
||||
|
||||
class GitOperationsComponent(
|
||||
CommandProvider, ConfigurableComponent[GitOperationsConfiguration]
|
||||
):
|
||||
"""Provides commands to perform Git operations."""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self._enabled = bool(config.github_username and config.github_api_key)
|
||||
config_class = GitOperationsConfiguration
|
||||
|
||||
def __init__(self, config: Optional[GitOperationsConfiguration] = None):
|
||||
ConfigurableComponent.__init__(self, config)
|
||||
self._enabled = bool(self.config.github_username and self.config.github_api_key)
|
||||
self._disabled_reason = "Configure github_username and github_api_key."
|
||||
self.legacy_config = config
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.clone_repository
|
||||
@@ -48,9 +61,13 @@ class GitOperationsComponent(CommandProvider):
|
||||
str: The result of the clone operation.
|
||||
"""
|
||||
split_url = url.split("//")
|
||||
auth_repo_url = (
|
||||
f"//{self.legacy_config.github_username}:"
|
||||
f"{self.legacy_config.github_api_key}@".join(split_url)
|
||||
api_key = (
|
||||
self.config.github_api_key.get_secret_value()
|
||||
if self.config.github_api_key
|
||||
else None
|
||||
)
|
||||
auth_repo_url = f"//{self.config.github_username}:" f"{api_key}@".join(
|
||||
split_url
|
||||
)
|
||||
try:
|
||||
Repo.clone_from(url=auth_repo_url, to_path=clone_path)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
"""Commands to generate images based on text input"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
@@ -7,35 +5,61 @@ import time
|
||||
import uuid
|
||||
from base64 import b64decode
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
from typing import Iterator, Literal, Optional
|
||||
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from forge.agent.components import ConfigurableComponent
|
||||
from forge.agent.protocols import CommandProvider
|
||||
from forge.command import Command, command
|
||||
from forge.config.config import Config
|
||||
from forge.file_storage import FileStorage
|
||||
from forge.llm.providers.openai import OpenAICredentials
|
||||
from forge.models.config import UserConfigurable
|
||||
from forge.models.json_schema import JSONSchema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImageGeneratorComponent(CommandProvider):
|
||||
class ImageGeneratorConfiguration(BaseModel):
    """Configuration model used by `ImageGeneratorComponent`."""

    # Which image generation backend to use
    image_provider: Literal["dalle", "huggingface", "sdwebui"] = "dalle"
    # presumably the model ID used with the `huggingface` provider - TODO confirm
    huggingface_image_model: str = "CompVis/stable-diffusion-v1-4"
    # Can be supplied via the HUGGINGFACE_API_TOKEN env variable;
    # `exclude=True` means it is skipped during serialization
    huggingface_api_token: Optional[SecretStr] = UserConfigurable(
        from_env="HUGGINGFACE_API_TOKEN", exclude=True
    )
    # presumably the base URL used with the `sdwebui` provider - TODO confirm
    sd_webui_url: str = "http://localhost:7860"
    # Can be supplied via the SD_WEBUI_AUTH env variable;
    # `exclude=True` means it is skipped during serialization
    sd_webui_auth: Optional[SecretStr] = UserConfigurable(
        from_env="SD_WEBUI_AUTH", exclude=True
    )
|
||||
|
||||
|
||||
class ImageGeneratorComponent(
|
||||
CommandProvider, ConfigurableComponent[ImageGeneratorConfiguration]
|
||||
):
|
||||
"""A component that provides commands to generate images from text prompts."""
|
||||
|
||||
def __init__(self, workspace: FileStorage, config: Config):
|
||||
self._enabled = bool(config.image_provider)
|
||||
config_class = ImageGeneratorConfiguration
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: FileStorage,
|
||||
config: Optional[ImageGeneratorConfiguration] = None,
|
||||
openai_credentials: Optional[OpenAICredentials] = None,
|
||||
):
|
||||
"""openai_credentials only needed for `dalle` provider."""
|
||||
ConfigurableComponent.__init__(self, config)
|
||||
self.openai_credentials = openai_credentials
|
||||
self._enabled = bool(self.config.image_provider)
|
||||
self._disabled_reason = "No image provider set."
|
||||
self.workspace = workspace
|
||||
self.legacy_config = config
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
if (
|
||||
self.legacy_config.openai_credentials
|
||||
or self.legacy_config.huggingface_api_token
|
||||
or self.legacy_config.sd_webui_auth
|
||||
self.openai_credentials
|
||||
or self.config.huggingface_api_token
|
||||
or self.config.sd_webui_auth
|
||||
):
|
||||
yield self.generate_image
|
||||
|
||||
@@ -48,7 +72,7 @@ class ImageGeneratorComponent(CommandProvider):
|
||||
),
|
||||
"size": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="The size of the image",
|
||||
description="The size of the image [256, 512, 1024]",
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
@@ -65,22 +89,21 @@ class ImageGeneratorComponent(CommandProvider):
|
||||
str: The filename of the image
|
||||
"""
|
||||
filename = self.workspace.root / f"{str(uuid.uuid4())}.jpg"
|
||||
cfg = self.legacy_config
|
||||
|
||||
if cfg.openai_credentials and (
|
||||
cfg.image_provider == "dalle"
|
||||
or not (cfg.huggingface_api_token or cfg.sd_webui_url)
|
||||
if self.openai_credentials and (
|
||||
self.config.image_provider == "dalle"
|
||||
or not (self.config.huggingface_api_token or self.config.sd_webui_url)
|
||||
):
|
||||
return self.generate_image_with_dalle(prompt, filename, size)
|
||||
|
||||
elif cfg.huggingface_api_token and (
|
||||
cfg.image_provider == "huggingface"
|
||||
or not (cfg.openai_credentials or cfg.sd_webui_url)
|
||||
elif self.config.huggingface_api_token and (
|
||||
self.config.image_provider == "huggingface"
|
||||
or not (self.openai_credentials or self.config.sd_webui_url)
|
||||
):
|
||||
return self.generate_image_with_hf(prompt, filename)
|
||||
|
||||
elif cfg.sd_webui_url and (
|
||||
cfg.image_provider == "sdwebui" or cfg.sd_webui_auth
|
||||
elif self.config.sd_webui_url and (
|
||||
self.config.image_provider == "sdwebui" or self.config.sd_webui_auth
|
||||
):
|
||||
return self.generate_image_with_sd_webui(prompt, filename, size)
|
||||
|
||||
@@ -96,13 +119,15 @@ class ImageGeneratorComponent(CommandProvider):
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
API_URL = f"https://api-inference.huggingface.co/models/{self.legacy_config.huggingface_image_model}" # noqa: E501
|
||||
if self.legacy_config.huggingface_api_token is None:
|
||||
API_URL = f"https://api-inference.huggingface.co/models/{self.config.huggingface_image_model}" # noqa: E501
|
||||
if self.config.huggingface_api_token is None:
|
||||
raise ValueError(
|
||||
"You need to set your Hugging Face API token in the config file."
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.legacy_config.huggingface_api_token}",
|
||||
"Authorization": (
|
||||
f"Bearer {self.config.huggingface_api_token.get_secret_value()}"
|
||||
),
|
||||
"X-Use-Cache": "false",
|
||||
}
|
||||
|
||||
@@ -156,7 +181,7 @@ class ImageGeneratorComponent(CommandProvider):
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
assert self.legacy_config.openai_credentials # otherwise this tool is disabled
|
||||
assert self.openai_credentials # otherwise this tool is disabled
|
||||
|
||||
# Check for supported image sizes
|
||||
if size not in [256, 512, 1024]:
|
||||
@@ -169,7 +194,10 @@ class ImageGeneratorComponent(CommandProvider):
|
||||
|
||||
# TODO: integrate in `forge.llm.providers`(?)
|
||||
response = OpenAI(
|
||||
api_key=self.legacy_config.openai_credentials.api_key.get_secret_value()
|
||||
api_key=self.openai_credentials.api_key.get_secret_value(),
|
||||
organization=self.openai_credentials.organization.get_secret_value()
|
||||
if self.openai_credentials.organization
|
||||
else None,
|
||||
).images.generate(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
@@ -208,13 +236,13 @@ class ImageGeneratorComponent(CommandProvider):
|
||||
"""
|
||||
# Create a session and set the basic auth if needed
|
||||
s = requests.Session()
|
||||
if self.legacy_config.sd_webui_auth:
|
||||
username, password = self.legacy_config.sd_webui_auth.split(":")
|
||||
if self.config.sd_webui_auth:
|
||||
username, password = self.config.sd_webui_auth.get_secret_value().split(":")
|
||||
s.auth = (username, password or "")
|
||||
|
||||
# Generate the images
|
||||
response = requests.post(
|
||||
f"{self.legacy_config.sd_webui_url}/sdapi/v1/txt2img",
|
||||
f"{self.config.sd_webui_url}/sdapi/v1/txt2img",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
|
||||
@@ -4,7 +4,6 @@ import click
|
||||
|
||||
from forge.agent.protocols import CommandProvider
|
||||
from forge.command import Command, command
|
||||
from forge.config.config import Config
|
||||
from forge.models.json_schema import JSONSchema
|
||||
from forge.utils.const import ASK_COMMAND
|
||||
|
||||
@@ -12,9 +11,6 @@ from forge.utils.const import ASK_COMMAND
|
||||
class UserInteractionComponent(CommandProvider):
|
||||
"""Provides commands to interact with the user."""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self._enabled = not config.noninteractive_mode
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.ask_user
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ class WatchdogComponent(AfterParse[AnyProposal]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: "BaseAgentConfiguration",
|
||||
config: BaseAgentConfiguration,
|
||||
event_history: EpisodicActionHistory[AnyProposal],
|
||||
):
|
||||
self.config = config
|
||||
|
||||
@@ -1,30 +1,44 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Iterator
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from duckduckgo_search import DDGS
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from forge.agent.components import ConfigurableComponent
|
||||
from forge.agent.protocols import CommandProvider, DirectiveProvider
|
||||
from forge.command import Command, command
|
||||
from forge.config.config import Config
|
||||
from forge.models.config import UserConfigurable
|
||||
from forge.models.json_schema import JSONSchema
|
||||
from forge.utils.exceptions import ConfigurationError
|
||||
|
||||
DUCKDUCKGO_MAX_ATTEMPTS = 3
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebSearchComponent(DirectiveProvider, CommandProvider):
|
||||
class WebSearchConfiguration(BaseModel):
    """Configuration model for `WebSearchComponent`."""

    google_api_key: Optional[SecretStr] = UserConfigurable(
        from_env="GOOGLE_API_KEY", exclude=True
    )
    """Developer key passed to the Google Custom Search API"""
    google_custom_search_engine_id: Optional[SecretStr] = UserConfigurable(
        from_env="GOOGLE_CUSTOM_SEARCH_ENGINE_ID", exclude=True
    )
    """Custom Search Engine ID (`cx`) used for Google searches"""
    duckduckgo_max_attempts: int = 3
    """Upper bound on the number of DuckDuckGo search attempts"""
|
||||
|
||||
|
||||
class WebSearchComponent(
|
||||
DirectiveProvider, CommandProvider, ConfigurableComponent[WebSearchConfiguration]
|
||||
):
|
||||
"""Provides commands to search the web."""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self.legacy_config = config
|
||||
config_class = WebSearchConfiguration
|
||||
|
||||
def __init__(self, config: Optional[WebSearchConfiguration] = None):
|
||||
ConfigurableComponent.__init__(self, config)
|
||||
|
||||
if (
|
||||
not self.legacy_config.google_api_key
|
||||
or not self.legacy_config.google_custom_search_engine_id
|
||||
not self.config.google_api_key
|
||||
or not self.config.google_custom_search_engine_id
|
||||
):
|
||||
logger.info(
|
||||
"Configure google_api_key and custom_search_engine_id "
|
||||
@@ -37,10 +51,7 @@ class WebSearchComponent(DirectiveProvider, CommandProvider):
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.web_search
|
||||
|
||||
if (
|
||||
self.legacy_config.google_api_key
|
||||
and self.legacy_config.google_custom_search_engine_id
|
||||
):
|
||||
if self.config.google_api_key and self.config.google_custom_search_engine_id:
|
||||
yield self.google
|
||||
|
||||
@command(
|
||||
@@ -74,7 +85,7 @@ class WebSearchComponent(DirectiveProvider, CommandProvider):
|
||||
search_results = []
|
||||
attempts = 0
|
||||
|
||||
while attempts < DUCKDUCKGO_MAX_ATTEMPTS:
|
||||
while attempts < self.config.duckduckgo_max_attempts:
|
||||
if not query:
|
||||
return json.dumps(search_results)
|
||||
|
||||
@@ -136,17 +147,25 @@ class WebSearchComponent(DirectiveProvider, CommandProvider):
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
try:
|
||||
# Get the Google API key and Custom Search Engine ID from the config file
|
||||
api_key = self.legacy_config.google_api_key
|
||||
custom_search_engine_id = self.legacy_config.google_custom_search_engine_id
|
||||
# Should be the case if this command is enabled:
|
||||
assert self.config.google_api_key
|
||||
assert self.config.google_custom_search_engine_id
|
||||
|
||||
# Initialize the Custom Search API service
|
||||
service = build("customsearch", "v1", developerKey=api_key)
|
||||
service = build(
|
||||
"customsearch",
|
||||
"v1",
|
||||
developerKey=self.config.google_api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
# Send the search query and retrieve the results
|
||||
result = (
|
||||
service.cse()
|
||||
.list(q=query, cx=custom_search_engine_id, num=num_results)
|
||||
.list(
|
||||
q=query,
|
||||
cx=self.config.google_custom_search_engine_id.get_secret_value(),
|
||||
num=num_results,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
@@ -3,10 +3,11 @@ import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from sys import platform
|
||||
from typing import Iterator, Type
|
||||
from typing import Iterator, Literal, Optional, Type
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from pydantic import BaseModel
|
||||
from selenium.common.exceptions import WebDriverException
|
||||
from selenium.webdriver.chrome.options import Options as ChromeOptions
|
||||
from selenium.webdriver.chrome.service import Service as ChromeDriverService
|
||||
@@ -27,12 +28,14 @@ from webdriver_manager.chrome import ChromeDriverManager
|
||||
from webdriver_manager.firefox import GeckoDriverManager
|
||||
from webdriver_manager.microsoft import EdgeChromiumDriverManager as EdgeDriverManager
|
||||
|
||||
from forge.agent.components import ConfigurableComponent
|
||||
from forge.agent.protocols import CommandProvider, DirectiveProvider
|
||||
from forge.command import Command, command
|
||||
from forge.config.config import Config
|
||||
from forge.content_processing.html import extract_hyperlinks, format_hyperlinks
|
||||
from forge.content_processing.text import extract_information, summarize_text
|
||||
from forge.llm.providers import ChatModelInfo, MultiProvider
|
||||
from forge.llm.providers import MultiProvider
|
||||
from forge.llm.providers.multi import ModelName
|
||||
from forge.llm.providers.openai import OpenAIModelName
|
||||
from forge.models.json_schema import JSONSchema
|
||||
from forge.utils.exceptions import CommandExecutionError, TooMuchOutputError
|
||||
from forge.utils.url_validator import validate_url
|
||||
@@ -51,18 +54,38 @@ class BrowsingError(CommandExecutionError):
|
||||
"""An error occurred while trying to browse the page"""
|
||||
|
||||
|
||||
class WebSeleniumComponent(DirectiveProvider, CommandProvider):
|
||||
class WebSeleniumConfiguration(BaseModel):
    """Configuration model for `WebSeleniumComponent`."""

    model_name: ModelName = OpenAIModelName.GPT3
    """Name of the llm model used to read websites"""
    web_browser: Literal["chrome", "firefox", "safari", "edge"] = "chrome"
    """Web browser used by Selenium"""
    headless: bool = True
    """Run browser in headless mode"""
    user_agent: str = (
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) "
        "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36"
    )
    """User agent used by the browser"""
    browse_spacy_language_model: str = "en_core_web_sm"
    """Spacy language model used for chunking text"""
|
||||
|
||||
|
||||
class WebSeleniumComponent(
|
||||
DirectiveProvider, CommandProvider, ConfigurableComponent[WebSeleniumConfiguration]
|
||||
):
|
||||
"""Provides commands to browse the web using Selenium."""
|
||||
|
||||
config_class = WebSeleniumConfiguration
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
llm_provider: MultiProvider,
|
||||
model_info: ChatModelInfo,
|
||||
data_dir: Path,
|
||||
config: Optional[WebSeleniumConfiguration] = None,
|
||||
):
|
||||
self.legacy_config = config
|
||||
ConfigurableComponent.__init__(self, config)
|
||||
self.llm_provider = llm_provider
|
||||
self.model_info = model_info
|
||||
self.data_dir = data_dir
|
||||
|
||||
def get_resources(self) -> Iterator[str]:
|
||||
yield "Ability to read websites."
|
||||
@@ -129,7 +152,7 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
|
||||
"""
|
||||
driver = None
|
||||
try:
|
||||
driver = await self.open_page_in_browser(url, self.legacy_config)
|
||||
driver = await self.open_page_in_browser(url)
|
||||
|
||||
text = self.scrape_text_with_selenium(driver)
|
||||
links = self.scrape_links_with_selenium(driver, url)
|
||||
@@ -141,7 +164,7 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
|
||||
elif get_raw_content:
|
||||
if (
|
||||
output_tokens := self.llm_provider.count_tokens(
|
||||
text, self.model_info.name
|
||||
text, self.config.model_name
|
||||
)
|
||||
) > MAX_RAW_CONTENT_LENGTH:
|
||||
oversize_factor = round(output_tokens / MAX_RAW_CONTENT_LENGTH, 1)
|
||||
@@ -228,7 +251,7 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
|
||||
|
||||
return format_hyperlinks(hyperlinks)
|
||||
|
||||
async def open_page_in_browser(self, url: str, config: Config) -> WebDriver:
|
||||
async def open_page_in_browser(self, url: str) -> WebDriver:
|
||||
"""Open a browser window and load a web page using Selenium
|
||||
|
||||
Params:
|
||||
@@ -248,11 +271,11 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
|
||||
"safari": SafariOptions,
|
||||
}
|
||||
|
||||
options: BrowserOptions = options_available[config.selenium_web_browser]()
|
||||
options.add_argument(f"user-agent={config.user_agent}")
|
||||
options: BrowserOptions = options_available[self.config.web_browser]()
|
||||
options.add_argument(f"user-agent={self.config.user_agent}")
|
||||
|
||||
if isinstance(options, FirefoxOptions):
|
||||
if config.selenium_headless:
|
||||
if self.config.headless:
|
||||
options.headless = True # type: ignore
|
||||
options.add_argument("--disable-gpu")
|
||||
driver = FirefoxDriver(
|
||||
@@ -274,13 +297,11 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
|
||||
options.add_argument("--remote-debugging-port=9222")
|
||||
|
||||
options.add_argument("--no-sandbox")
|
||||
if config.selenium_headless:
|
||||
if self.config.headless:
|
||||
options.add_argument("--headless=new")
|
||||
options.add_argument("--disable-gpu")
|
||||
|
||||
self._sideload_chrome_extensions(
|
||||
options, config.app_data_dir / "assets" / "crx"
|
||||
)
|
||||
self._sideload_chrome_extensions(options, self.data_dir / "assets" / "crx")
|
||||
|
||||
if (chromium_driver_path := Path("/usr/bin/chromedriver")).exists():
|
||||
chrome_service = ChromeDriverService(str(chromium_driver_path))
|
||||
@@ -361,7 +382,8 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
|
||||
text,
|
||||
topics_of_interest=topics_of_interest,
|
||||
llm_provider=self.llm_provider,
|
||||
config=self.legacy_config,
|
||||
model_name=self.config.model_name,
|
||||
spacy_model=self.config.browse_spacy_language_model,
|
||||
)
|
||||
return "\n".join(f"* {i}" for i in information)
|
||||
else:
|
||||
@@ -369,6 +391,7 @@ class WebSeleniumComponent(DirectiveProvider, CommandProvider):
|
||||
text,
|
||||
question=question,
|
||||
llm_provider=self.llm_provider,
|
||||
config=self.legacy_config,
|
||||
model_name=self.config.model_name,
|
||||
spacy_model=self.config.browse_spacy_language_model,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -3,12 +3,10 @@ This module contains configuration models and helpers for AutoGPT Forge.
|
||||
"""
|
||||
from .ai_directives import AIDirectives
|
||||
from .ai_profile import AIProfile
|
||||
from .config import Config, ConfigBuilder, assert_config_has_required_llm_api_keys
|
||||
from .base import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"assert_config_has_required_llm_api_keys",
|
||||
"AIProfile",
|
||||
"AIDirectives",
|
||||
"Config",
|
||||
"ConfigBuilder",
|
||||
"BaseConfig",
|
||||
]
|
||||
|
||||
16
forge/forge/config/base.py
Normal file
16
forge/forge/config/base.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from forge.file_storage import FileStorageBackendName
|
||||
from forge.models.config import SystemSettings, UserConfigurable
|
||||
from forge.speech.say import TTSConfig
|
||||
|
||||
|
||||
class BaseConfig(SystemSettings):
    """Base settings model holding the TTS and file storage configuration."""

    name: str = "Base configuration"
    description: str = "Default configuration for forge agent."

    # TTS configuration
    tts_config: TTSConfig = TTSConfig()

    # File storage
    file_storage_backend: FileStorageBackendName = UserConfigurable(
        default=FileStorageBackendName.LOCAL, from_env="FILE_STORAGE_BACKEND"
    )
|
||||
@@ -1,302 +0,0 @@
|
||||
"""Configuration class to store the state of bools for different scripts access."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import SecretStr, validator
|
||||
|
||||
import forge
|
||||
from forge.file_storage import FileStorageBackendName
|
||||
from forge.llm.providers import CHAT_MODELS, ModelName
|
||||
from forge.llm.providers.openai import OpenAICredentials, OpenAIModelName
|
||||
from forge.logging.config import LoggingConfig
|
||||
from forge.models.config import Configurable, SystemSettings, UserConfigurable
|
||||
from forge.speech.say import TTSConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROJECT_ROOT = Path(forge.__file__).parent.parent
|
||||
AZURE_CONFIG_FILE = Path("azure.yaml")
|
||||
|
||||
GPT_4_MODEL = OpenAIModelName.GPT4
|
||||
GPT_3_MODEL = OpenAIModelName.GPT3
|
||||
|
||||
|
||||
class Config(SystemSettings, arbitrary_types_allowed=True):
    """Application-wide settings for Auto-GPT.

    Fields declared with `UserConfigurable(from_env=...)` name the environment
    variable (or callable) their value is taken from.
    """

    name: str = "Auto-GPT configuration"
    description: str = "Default configuration for the Auto-GPT application."

    ########################
    # Application Settings #
    ########################
    project_root: Path = PROJECT_ROOT
    app_data_dir: Path = project_root / "data"
    skip_news: bool = False
    skip_reprompt: bool = False
    authorise_key: str = UserConfigurable(default="y", from_env="AUTHORISE_COMMAND_KEY")
    exit_key: str = UserConfigurable(default="n", from_env="EXIT_KEY")
    noninteractive_mode: bool = False

    # Logging and TTS configuration
    logging: LoggingConfig = LoggingConfig()
    tts_config: TTSConfig = TTSConfig()

    # File storage
    file_storage_backend: FileStorageBackendName = UserConfigurable(
        default=FileStorageBackendName.LOCAL, from_env="FILE_STORAGE_BACKEND"
    )

    ##########################
    # Agent Control Settings #
    ##########################
    # Model configuration
    fast_llm: ModelName = UserConfigurable(
        default=OpenAIModelName.GPT3,
        from_env="FAST_LLM",
    )
    smart_llm: ModelName = UserConfigurable(
        default=OpenAIModelName.GPT4_TURBO,
        from_env="SMART_LLM",
    )
    temperature: float = UserConfigurable(default=0, from_env="TEMPERATURE")
    openai_functions: bool = UserConfigurable(
        default=False, from_env=lambda: os.getenv("OPENAI_FUNCTIONS", "False") == "True"
    )
    embedding_model: str = UserConfigurable(
        default="text-embedding-3-small", from_env="EMBEDDING_MODEL"
    )
    browse_spacy_language_model: str = UserConfigurable(
        default="en_core_web_sm", from_env="BROWSE_SPACY_LANGUAGE_MODEL"
    )

    # Run loop configuration
    continuous_mode: bool = False
    continuous_limit: int = 0

    ############
    # Commands #
    ############
    # General
    disabled_commands: list[str] = UserConfigurable(
        default_factory=list,
        from_env=lambda: _safe_split(os.getenv("DISABLED_COMMANDS")),
    )

    # File ops
    restrict_to_workspace: bool = UserConfigurable(
        default=True,
        from_env=lambda: os.getenv("RESTRICT_TO_WORKSPACE", "True") == "True",
    )
    allow_downloads: bool = False

    # Shell commands
    shell_command_control: str = UserConfigurable(
        default="denylist", from_env="SHELL_COMMAND_CONTROL"
    )
    execute_local_commands: bool = UserConfigurable(
        default=False,
        from_env=lambda: os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True",
    )
    shell_denylist: list[str] = UserConfigurable(
        default_factory=lambda: ["sudo", "su"],
        # SHELL_DENYLIST takes precedence; DENY_COMMANDS is the fallback name
        from_env=lambda: _safe_split(
            os.getenv("SHELL_DENYLIST", os.getenv("DENY_COMMANDS"))
        ),
    )
    shell_allowlist: list[str] = UserConfigurable(
        default_factory=list,
        # SHELL_ALLOWLIST takes precedence; ALLOW_COMMANDS is the fallback name
        from_env=lambda: _safe_split(
            os.getenv("SHELL_ALLOWLIST", os.getenv("ALLOW_COMMANDS"))
        ),
    )

    # Text to image
    image_provider: Optional[str] = UserConfigurable(from_env="IMAGE_PROVIDER")
    huggingface_image_model: str = UserConfigurable(
        default="CompVis/stable-diffusion-v1-4", from_env="HUGGINGFACE_IMAGE_MODEL"
    )
    sd_webui_url: Optional[str] = UserConfigurable(
        default="http://localhost:7860", from_env="SD_WEBUI_URL"
    )
    image_size: int = UserConfigurable(default=256, from_env="IMAGE_SIZE")

    # Audio to text
    audio_to_text_provider: str = UserConfigurable(
        default="huggingface", from_env="AUDIO_TO_TEXT_PROVIDER"
    )
    huggingface_audio_to_text_model: Optional[str] = UserConfigurable(
        from_env="HUGGINGFACE_AUDIO_TO_TEXT_MODEL"
    )

    # Web browsing
    selenium_web_browser: str = UserConfigurable("chrome", from_env="USE_WEB_BROWSER")
    selenium_headless: bool = UserConfigurable(
        default=True, from_env=lambda: os.getenv("HEADLESS_BROWSER", "True") == "True"
    )
    user_agent: str = UserConfigurable(
        default="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",  # noqa: E501
        from_env="USER_AGENT",
    )

    ###############
    # Credentials #
    ###############
    # OpenAI
    openai_credentials: Optional[OpenAICredentials] = None
    azure_config_file: Optional[Path] = UserConfigurable(
        default=AZURE_CONFIG_FILE, from_env="AZURE_CONFIG_FILE"
    )

    # Github
    github_api_key: Optional[str] = UserConfigurable(from_env="GITHUB_API_KEY")
    github_username: Optional[str] = UserConfigurable(from_env="GITHUB_USERNAME")

    # Google
    google_api_key: Optional[str] = UserConfigurable(from_env="GOOGLE_API_KEY")
    google_custom_search_engine_id: Optional[str] = UserConfigurable(
        from_env="GOOGLE_CUSTOM_SEARCH_ENGINE_ID",
    )

    # Huggingface
    huggingface_api_token: Optional[str] = UserConfigurable(
        from_env="HUGGINGFACE_API_TOKEN"
    )

    # Stable Diffusion
    sd_webui_auth: Optional[str] = UserConfigurable(from_env="SD_WEBUI_AUTH")

    @validator("openai_functions")
    def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
        """Reject `openai_functions=True` when `smart_llm` lacks tool calling.

        Raises:
            ValueError: if the configured `smart_llm` has no function call API.
        """
        if v:
            # `values` only holds fields that validated successfully, so
            # `smart_llm` is absent if it was rejected; that failure is
            # reported on its own and there is nothing to check here.
            smart_llm = values.get("smart_llm")
            if smart_llm is None:
                return v
            # Explicit raise instead of `assert`: asserts vanish under `-O`.
            if not CHAT_MODELS[smart_llm].has_function_call_api:
                raise ValueError(
                    f"Model {smart_llm} does not support tool calling. "
                    "Please disable OPENAI_FUNCTIONS or choose a suitable model."
                )
        return v
|
||||
|
||||
|
||||
class ConfigBuilder(Configurable[Config]):
    """Factory that produces the application `Config`."""

    default_settings = Config()

    @classmethod
    def build_config_from_env(cls, project_root: Path = PROJECT_ROOT) -> Config:
        """Initialize the Config class

        Params:
            project_root: Directory that relative config paths are resolved
                against; also stored on the returned config.

        Returns:
            Config: The built configuration, with the Azure settings loaded
                into the OpenAI credentials when their API type is "azure".
        """

        config = cls.build_agent_configuration()
        config.project_root = project_root

        # Make relative paths absolute
        for k in (
            "azure_config_file",  # TODO: move from project root
        ):
            # The field is Optional; joining a Path with None raises TypeError
            if (path := getattr(config, k)) is not None:
                setattr(config, k, project_root / path)

        if (
            config.openai_credentials
            and config.openai_credentials.api_type == SecretStr("azure")
            and (config_file := config.azure_config_file)
        ):
            config.openai_credentials.load_azure_config(config_file)

        return config
|
||||
|
||||
|
||||
async def assert_config_has_required_llm_api_keys(config: Config) -> None:
    """
    Check if API keys (if required) are set for the configured SMART_LLM and FAST_LLM.

    Only the providers whose models are selected in `config.smart_llm` or
    `config.fast_llm` are checked.

    Raises:
        ValueError: if a required provider's credentials can't be loaded,
            or its API key is rejected (Groq and OpenAI only).
    """
    from pydantic import ValidationError

    from forge.llm.providers.anthropic import AnthropicModelName
    from forge.llm.providers.groq import GroqModelName

    configured_models = {config.smart_llm, config.fast_llm}

    if configured_models.intersection(AnthropicModelName):
        from forge.llm.providers.anthropic import AnthropicCredentials

        try:
            credentials = AnthropicCredentials.from_env()
        except ValidationError as e:
            if "api_key" in str(e):
                logger.error(
                    "Set your Anthropic API key in .env or as an environment variable"
                )
                logger.info(
                    "For further instructions: "
                    "https://docs.agpt.co/autogpt/setup/#anthropic"
                )

            raise ValueError("Anthropic is unavailable: can't load credentials") from e

        key_pattern = r"^sk-ant-api03-[\w\-]{95}"

        # If key is set, but it looks invalid
        if not re.search(key_pattern, credentials.api_key.get_secret_value()):
            logger.warning(
                "Possibly invalid Anthropic API key! "
                f"Configured Anthropic API key does not match pattern '{key_pattern}'. "
                "If this is a valid key, please report this warning to the maintainers."
            )

    if configured_models.intersection(GroqModelName):
        from groq import AuthenticationError

        from forge.llm.providers.groq import GroqProvider

        try:
            groq = GroqProvider()
            await groq.get_available_models()
        except ValidationError as e:
            # Validation failures unrelated to the API key are passed through
            if "api_key" not in str(e):
                raise

            logger.error("Set your Groq API key in .env or as an environment variable")
            logger.info(
                "For further instructions: https://docs.agpt.co/autogpt/setup/#groq"
            )
            raise ValueError("Groq is unavailable: can't load credentials") from e
        except AuthenticationError as e:
            logger.error("The Groq API key is invalid!")
            logger.info(
                "For instructions to get and set a new API key: "
                "https://docs.agpt.co/autogpt/setup/#groq"
            )
            raise ValueError("Groq is unavailable: invalid API key") from e

    if configured_models.intersection(OpenAIModelName):
        from openai import AuthenticationError

        from forge.llm.providers.openai import OpenAIProvider

        try:
            openai = OpenAIProvider()
            await openai.get_available_models()
        except ValidationError as e:
            # Validation failures unrelated to the API key are passed through
            if "api_key" not in str(e):
                raise

            logger.error(
                "Set your OpenAI API key in .env or as an environment variable"
            )
            logger.info(
                "For further instructions: https://docs.agpt.co/autogpt/setup/#openai"
            )
            raise ValueError("OpenAI is unavailable: can't load credentials") from e
        except AuthenticationError as e:
            logger.error("The OpenAI API key is invalid!")
            logger.info(
                "For instructions to get and set a new API key: "
                "https://docs.agpt.co/autogpt/setup/#openai"
            )
            raise ValueError("OpenAI is unavailable: invalid API key") from e
|
||||
|
||||
|
||||
def _safe_split(s: Union[str, None], sep: str = ",") -> list[str]:
    """Split a string by a separator. Return an empty list if the string is None.

    Surrounding whitespace is stripped from every item and empty items are
    dropped, so `"sudo, su"` yields `["sudo", "su"]` and `""` yields `[]`.
    """
    if s is None:
        return []
    return [item for part in s.split(sep) if (item := part.strip())]
|
||||
@@ -3,16 +3,14 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Iterator, Optional, TypeVar
|
||||
from typing import Iterator, Optional, TypeVar
|
||||
|
||||
import spacy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from forge.config.config import Config
|
||||
|
||||
from forge.json.parsing import extract_list_from_json
|
||||
from forge.llm.prompting import ChatPrompt
|
||||
from forge.llm.providers import ChatMessage, ModelTokenizer, MultiProvider
|
||||
from forge.llm.providers.multi import ModelName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -57,7 +55,8 @@ def chunk_content(
|
||||
async def summarize_text(
|
||||
text: str,
|
||||
llm_provider: MultiProvider,
|
||||
config: Config,
|
||||
model_name: ModelName,
|
||||
spacy_model: str = "en_core_web_sm",
|
||||
question: Optional[str] = None,
|
||||
instruction: Optional[str] = None,
|
||||
) -> tuple[str, list[tuple[str, str]]]:
|
||||
@@ -82,7 +81,8 @@ async def summarize_text(
|
||||
text=text,
|
||||
instruction=instruction,
|
||||
llm_provider=llm_provider,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
spacy_model=spacy_model,
|
||||
)
|
||||
|
||||
|
||||
@@ -90,7 +90,8 @@ async def extract_information(
|
||||
source_text: str,
|
||||
topics_of_interest: list[str],
|
||||
llm_provider: MultiProvider,
|
||||
config: Config,
|
||||
model_name: ModelName,
|
||||
spacy_model: str = "en_core_web_sm",
|
||||
) -> list[str]:
|
||||
fmt_topics_list = "\n".join(f"* {topic}." for topic in topics_of_interest)
|
||||
instruction = (
|
||||
@@ -106,7 +107,8 @@ async def extract_information(
|
||||
instruction=instruction,
|
||||
output_type=list[str],
|
||||
llm_provider=llm_provider,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
spacy_model=spacy_model,
|
||||
)
|
||||
|
||||
|
||||
@@ -114,7 +116,8 @@ async def _process_text(
|
||||
text: str,
|
||||
instruction: str,
|
||||
llm_provider: MultiProvider,
|
||||
config: Config,
|
||||
model_name: ModelName,
|
||||
spacy_model: str = "en_core_web_sm",
|
||||
output_type: type[str | list[str]] = str,
|
||||
) -> tuple[str, list[tuple[str, str]]] | list[str]:
|
||||
"""Process text using the OpenAI API for summarization or information extraction
|
||||
@@ -123,7 +126,8 @@ async def _process_text(
|
||||
text (str): The text to process.
|
||||
instruction (str): Additional instruction for processing.
|
||||
llm_provider: LLM provider to use.
|
||||
config (Config): The global application config.
|
||||
model_name: The name of the llm model to use.
|
||||
spacy_model: The spaCy model to use for sentence splitting.
|
||||
output_type: `str` for summaries or `list[str]` for piece-wise info extraction.
|
||||
|
||||
Returns:
|
||||
@@ -133,13 +137,11 @@ async def _process_text(
|
||||
if not text.strip():
|
||||
raise ValueError("No content")
|
||||
|
||||
model = config.fast_llm
|
||||
|
||||
text_tlength = llm_provider.count_tokens(text, model)
|
||||
text_tlength = llm_provider.count_tokens(text, model_name)
|
||||
logger.debug(f"Text length: {text_tlength} tokens")
|
||||
|
||||
max_result_tokens = 500
|
||||
max_chunk_length = llm_provider.get_token_limit(model) - max_result_tokens - 50
|
||||
max_chunk_length = llm_provider.get_token_limit(model_name) - max_result_tokens - 50
|
||||
logger.debug(f"Max chunk length: {max_chunk_length} tokens")
|
||||
|
||||
if text_tlength < max_chunk_length:
|
||||
@@ -157,7 +159,7 @@ async def _process_text(
|
||||
|
||||
response = await llm_provider.create_chat_completion(
|
||||
model_prompt=prompt.messages,
|
||||
model_name=model,
|
||||
model_name=model_name,
|
||||
temperature=0.5,
|
||||
max_output_tokens=max_result_tokens,
|
||||
completion_parser=lambda s: (
|
||||
@@ -182,9 +184,9 @@ async def _process_text(
|
||||
chunks = list(
|
||||
split_text(
|
||||
text,
|
||||
config=config,
|
||||
max_chunk_length=max_chunk_length,
|
||||
tokenizer=llm_provider.get_tokenizer(model),
|
||||
tokenizer=llm_provider.get_tokenizer(model_name),
|
||||
spacy_model=spacy_model,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -196,7 +198,8 @@ async def _process_text(
|
||||
instruction=instruction,
|
||||
output_type=output_type,
|
||||
llm_provider=llm_provider,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
spacy_model=spacy_model,
|
||||
)
|
||||
processed_results.extend(
|
||||
chunk_result if output_type == list[str] else [chunk_result]
|
||||
@@ -212,7 +215,8 @@ async def _process_text(
|
||||
"Combine these partial summaries into one."
|
||||
),
|
||||
llm_provider=llm_provider,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
spacy_model=spacy_model,
|
||||
)
|
||||
return summary.strip(), [
|
||||
(processed_results[i], chunks[i][0]) for i in range(0, len(chunks))
|
||||
@@ -221,9 +225,9 @@ async def _process_text(
|
||||
|
||||
def split_text(
|
||||
text: str,
|
||||
config: Config,
|
||||
max_chunk_length: int,
|
||||
tokenizer: ModelTokenizer,
|
||||
spacy_model: str = "en_core_web_sm",
|
||||
with_overlap: bool = True,
|
||||
) -> Iterator[tuple[str, int]]:
|
||||
"""
|
||||
@@ -231,7 +235,7 @@ def split_text(
|
||||
|
||||
Args:
|
||||
text (str): The text to split.
|
||||
config (Config): Config object containing the Spacy model setting.
|
||||
spacy_model (str): The spaCy model to use for sentence splitting.
|
||||
max_chunk_length (int): The maximum length of a chunk, in tokens.
|
||||
tokenizer (ModelTokenizer): Tokenizer to use for determining chunk length.
|
||||
with_overlap (bool, optional): Whether to allow overlap between chunks.
|
||||
@@ -251,7 +255,7 @@ def split_text(
|
||||
n_chunks = math.ceil(text_length / max_chunk_length)
|
||||
target_chunk_length = math.ceil(text_length / n_chunks)
|
||||
|
||||
nlp: spacy.language.Language = spacy.load(config.browse_spacy_language_model)
|
||||
nlp: spacy.language.Language = spacy.load(spacy_model)
|
||||
nlp.add_pipe("sentencizer")
|
||||
doc = nlp(text)
|
||||
sentences = [sentence.text.strip() for sentence in doc.sents]
|
||||
|
||||
@@ -16,6 +16,7 @@ def UserConfigurable(
|
||||
default_factory: Optional[Callable[[], T]] = None,
|
||||
from_env: Optional[str | Callable[[], T | None]] = None,
|
||||
description: str = "",
|
||||
exclude: bool = False,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
# TODO: use this to auto-generate docs for the application configuration
|
||||
@@ -25,6 +26,7 @@ def UserConfigurable(
|
||||
default_factory=default_factory,
|
||||
from_env=from_env,
|
||||
description=description,
|
||||
exclude=exclude,
|
||||
**kwargs,
|
||||
user_configurable=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user