mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-06 22:03:59 -05:00
109 lines
3.4 KiB
Python
109 lines
3.4 KiB
Python
import json
|
|
import os
|
|
from typing import Any, Dict, Generic, Set, Tuple, Type, TypeVar
|
|
from pydantic import BaseModel, Field, PrivateAttr
|
|
from pydantic_settings import (
|
|
BaseSettings,
|
|
JsonConfigSettingsSource,
|
|
PydanticBaseSettingsSource,
|
|
SettingsConfigDict,
|
|
)
|
|
|
|
from autogpt_server.util.data import get_config_path, get_data_path, get_secrets_path
|
|
|
|
T = TypeVar("T", bound=BaseSettings)
|
|
|
|
|
|
class UpdateTrackingModel(BaseModel, Generic[T]):
|
|
_updated_fields: Set[str] = PrivateAttr(default_factory=set)
|
|
|
|
def __setattr__(self, name: str, value) -> None:
|
|
if name in self.model_fields:
|
|
self._updated_fields.add(name)
|
|
super().__setattr__(name, value)
|
|
|
|
|
|
def mark_updated(self, field_name: str) -> None:
|
|
if field_name in self.model_fields:
|
|
self._updated_fields.add(field_name)
|
|
|
|
def clear_updates(self) -> None:
|
|
self._updated_fields.clear()
|
|
|
|
def get_updates(self) -> Dict[str, Any]:
|
|
return {field: getattr(self, field) for field in self._updated_fields}
|
|
|
|
|
|
class Config(UpdateTrackingModel["Config"], BaseSettings):
|
|
"""Config for the server."""
|
|
|
|
num_workers: int = Field(
|
|
default=9, ge=1, le=100, description="Number of workers to use for execution."
|
|
)
|
|
# Add more configuration fields as needed
|
|
|
|
model_config = SettingsConfigDict(
|
|
json_file=[
|
|
get_config_path() / "config.default.json",
|
|
get_config_path() / "config.json",
|
|
],
|
|
env_file=".env",
|
|
env_file_encoding="utf-8",
|
|
extra="allow",
|
|
)
|
|
|
|
@classmethod
|
|
def settings_customise_sources(
|
|
cls,
|
|
settings_cls: Type[BaseSettings],
|
|
init_settings: PydanticBaseSettingsSource,
|
|
env_settings: PydanticBaseSettingsSource,
|
|
dotenv_settings: PydanticBaseSettingsSource,
|
|
file_secret_settings: PydanticBaseSettingsSource,
|
|
) -> Tuple[PydanticBaseSettingsSource, ...]:
|
|
return (JsonConfigSettingsSource(settings_cls),)
|
|
|
|
|
|
class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
|
"""Secrets for the server."""
|
|
|
|
database_password: str = ""
|
|
# Add more secret fields as needed
|
|
|
|
model_config = SettingsConfigDict(
|
|
secrets_dir=get_secrets_path(),
|
|
env_file=".env",
|
|
env_file_encoding="utf-8",
|
|
extra="allow",
|
|
)
|
|
|
|
|
|
class Settings(BaseModel):
|
|
config: Config = Config()
|
|
secrets: Secrets = Secrets()
|
|
|
|
def save(self) -> None:
|
|
# Save updated config to JSON file
|
|
if self.config._updated_fields:
|
|
config_to_save = self.config.get_updates()
|
|
config_path = os.path.join(get_data_path(), "config.json")
|
|
if os.path.exists(config_path):
|
|
with open(config_path, "r+") as f:
|
|
existing_config: Dict[str, Any] = json.load(f)
|
|
existing_config.update(config_to_save)
|
|
f.seek(0)
|
|
json.dump(existing_config, f, indent=2)
|
|
f.truncate()
|
|
else:
|
|
with open(config_path, "w") as f:
|
|
json.dump(config_to_save, f, indent=2)
|
|
self.config.clear_updates()
|
|
|
|
# Save updated secrets to individual files
|
|
secrets_dir = get_secrets_path()
|
|
for key in self.secrets._updated_fields:
|
|
secret_file = os.path.join(secrets_dir, key)
|
|
with open(secret_file, "w") as f:
|
|
f.write(str(getattr(self.secrets, key)))
|
|
self.secrets.clear_updates()
|