mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 07:08:09 -05:00
The changes in this PR are to add Llama API support. ### Changes 🏗️ We add both backend and frontend support. **Backend**: - Add llama_api provider - Include models supported by Llama API along with configs - llm_call - credential store and llama_api_key field in Settings **Frontend**: - Llama API as a type - Credentials input and provider for Llama API ### Checklist 📋 #### For code changes: - [X] I have clearly listed my changes in the PR description - [X] I have tested my changes according to the test plan: **Test Plan**: <details> <summary>AI Text Generator</summary> - [X] Start-up backend and frontend: - Start backend with Docker services: `docker compose up -d --build` - Start frontend: `npm install && npm run dev` - By visiting http://localhost:3000/, test inference and structured outputs - [X] Create from scratch - [X] Request for Llama API Credentials <img width="2015" alt="image" src="https://github.com/user-attachments/assets/3dede402-3718-4441-9327-ecab25c63ebf" /> - [X] Execute an agent with at least 3 blocks <img width="2026" alt="image" src="https://github.com/user-attachments/assets/59d6d56b-2ccc-4af5-b511-4af312c3f7f8" /> - [X] Confirm it executes correctly </details> <details> <summary>Structured Response Generator</summary> - [X] Start-up backend and frontend: - Start backend with Docker services: `docker compose up -d --build` - Start frontend: `npm install && npm run dev` - By visiting http://localhost:3000/, test inference and structured outputs - [X] Create from scratch - [X] Execute an agent <img width="2023" alt="image" src="https://github.com/user-attachments/assets/d1107638-bf1b-45b1-a296-1e0fac29525b" /> - [X] Confirm it executes correctly </details> --------- Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
471 lines
16 KiB
Python
471 lines
16 KiB
Python
import json
|
|
import os
|
|
from enum import Enum
|
|
from typing import Any, Dict, Generic, List, Set, Tuple, Type, TypeVar
|
|
|
|
from pydantic import BaseModel, Field, PrivateAttr, ValidationInfo, field_validator
|
|
from pydantic_settings import (
|
|
BaseSettings,
|
|
JsonConfigSettingsSource,
|
|
PydanticBaseSettingsSource,
|
|
SettingsConfigDict,
|
|
)
|
|
|
|
from backend.util.data import get_data_path
|
|
|
|
T = TypeVar("T", bound=BaseSettings)
|
|
|
|
|
|
class AppEnvironment(str, Enum):
|
|
LOCAL = "local"
|
|
DEVELOPMENT = "dev"
|
|
PRODUCTION = "prod"
|
|
|
|
|
|
class BehaveAs(str, Enum):
|
|
LOCAL = "local"
|
|
CLOUD = "cloud"
|
|
|
|
|
|
class UpdateTrackingModel(BaseModel, Generic[T]):
|
|
_updated_fields: Set[str] = PrivateAttr(default_factory=set)
|
|
|
|
def __setattr__(self, name: str, value) -> None:
|
|
if name in UpdateTrackingModel.model_fields:
|
|
self._updated_fields.add(name)
|
|
super().__setattr__(name, value)
|
|
|
|
def mark_updated(self, field_name: str) -> None:
|
|
if field_name in UpdateTrackingModel.model_fields:
|
|
self._updated_fields.add(field_name)
|
|
|
|
def clear_updates(self) -> None:
|
|
self._updated_fields.clear()
|
|
|
|
def get_updates(self) -> Dict[str, Any]:
|
|
return {field: getattr(self, field) for field in self._updated_fields}
|
|
|
|
@property
|
|
def updated_fields(self):
|
|
return self._updated_fields
|
|
|
|
|
|
class Config(UpdateTrackingModel["Config"], BaseSettings):
|
|
"""Config for the server."""
|
|
|
|
num_graph_workers: int = Field(
|
|
default=10,
|
|
ge=1,
|
|
le=1000,
|
|
description="Maximum number of workers to use for graph execution.",
|
|
)
|
|
num_node_workers: int = Field(
|
|
default=5,
|
|
ge=1,
|
|
le=1000,
|
|
description="Maximum number of workers to use for node execution within a single graph.",
|
|
)
|
|
pyro_host: str = Field(
|
|
default="localhost",
|
|
description="The default hostname of the Pyro server.",
|
|
)
|
|
pyro_client_comm_timeout: float = Field(
|
|
default=15,
|
|
description="The default timeout in seconds, for Pyro client connections.",
|
|
)
|
|
pyro_client_comm_retry: int = Field(
|
|
default=3,
|
|
description="The default number of retries for Pyro client connections.",
|
|
)
|
|
rpc_client_call_timeout: int = Field(
|
|
default=300,
|
|
description="The default timeout in seconds, for RPC client calls.",
|
|
)
|
|
enable_auth: bool = Field(
|
|
default=True,
|
|
description="If authentication is enabled or not",
|
|
)
|
|
enable_credit: bool = Field(
|
|
default=False,
|
|
description="If user credit system is enabled or not",
|
|
)
|
|
enable_beta_monthly_credit: bool = Field(
|
|
default=True,
|
|
description="If beta monthly credits accounting is enabled or not",
|
|
)
|
|
num_user_credits_refill: int = Field(
|
|
default=1500,
|
|
description="Number of credits to refill for each user",
|
|
)
|
|
refund_credit_tolerance_threshold: int = Field(
|
|
default=500,
|
|
description="Maximum number of credits above the balance to be auto-approved.",
|
|
)
|
|
refund_notification_email: str = Field(
|
|
default="refund@agpt.co",
|
|
description="Email address to send refund notifications to.",
|
|
)
|
|
refund_request_time_key_format: str = Field(
|
|
default="%Y-%W", # This will allow for weekly refunds per user.
|
|
description="Time key format for refund requests.",
|
|
)
|
|
execution_cost_count_threshold: int = Field(
|
|
default=100,
|
|
description="Number of executions after which the cost is calculated.",
|
|
)
|
|
execution_cost_per_threshold: int = Field(
|
|
default=1,
|
|
description="Cost per execution in cents after each threshold.",
|
|
)
|
|
execution_counter_expiration_time: int = Field(
|
|
default=60 * 60 * 24,
|
|
description="Time in seconds after which the execution counter is reset.",
|
|
)
|
|
execution_late_notification_threshold_secs: int = Field(
|
|
default=5 * 60,
|
|
description="Time in seconds after which the execution stuck on QUEUED status is considered late.",
|
|
)
|
|
execution_late_notification_checkrange_secs: int = Field(
|
|
default=60 * 60,
|
|
description="Time in seconds for how far back to check for the late executions.",
|
|
)
|
|
|
|
model_config = SettingsConfigDict(
|
|
env_file=".env",
|
|
extra="allow",
|
|
)
|
|
|
|
websocket_server_host: str = Field(
|
|
default="0.0.0.0",
|
|
description="The host for the websocket server to run on",
|
|
)
|
|
|
|
websocket_server_port: int = Field(
|
|
default=8001,
|
|
description="The port for the websocket server to run on",
|
|
)
|
|
|
|
execution_manager_port: int = Field(
|
|
default=8002,
|
|
description="The port for execution manager daemon to run on",
|
|
)
|
|
|
|
execution_scheduler_port: int = Field(
|
|
default=8003,
|
|
description="The port for execution scheduler daemon to run on",
|
|
)
|
|
|
|
agent_server_port: int = Field(
|
|
default=8004,
|
|
description="The port for agent server daemon to run on",
|
|
)
|
|
|
|
database_api_port: int = Field(
|
|
default=8005,
|
|
description="The port for database server API to run on",
|
|
)
|
|
|
|
agent_api_host: str = Field(
|
|
default="0.0.0.0",
|
|
description="The host for agent server API to run on",
|
|
)
|
|
|
|
agent_api_port: int = Field(
|
|
default=8006,
|
|
description="The port for agent server API to run on",
|
|
)
|
|
|
|
notification_service_port: int = Field(
|
|
default=8007,
|
|
description="The port for notification service daemon to run on",
|
|
)
|
|
|
|
otto_api_url: str = Field(
|
|
default="",
|
|
description="The URL for the Otto API service",
|
|
)
|
|
|
|
platform_base_url: str = Field(
|
|
default="",
|
|
description="Must be set so the application knows where it's hosted at. "
|
|
"This is necessary to make sure webhooks find their way.",
|
|
)
|
|
|
|
frontend_base_url: str = Field(
|
|
default="",
|
|
description="Can be used to explicitly set the base URL for the frontend. "
|
|
"This value is then used to generate redirect URLs for OAuth flows.",
|
|
)
|
|
|
|
media_gcs_bucket_name: str = Field(
|
|
default="",
|
|
description="The name of the Google Cloud Storage bucket for media files",
|
|
)
|
|
|
|
reddit_user_agent: str = Field(
|
|
default="AutoGPT:1.0 (by /u/autogpt)",
|
|
description="The user agent for the Reddit API",
|
|
)
|
|
|
|
scheduler_db_pool_size: int = Field(
|
|
default=3,
|
|
description="The pool size for the scheduler database connection pool",
|
|
)
|
|
|
|
rabbitmq_host: str = Field(
|
|
default="localhost",
|
|
description="The host for the RabbitMQ server",
|
|
)
|
|
rabbitmq_port: int = Field(
|
|
default=5672,
|
|
description="The port for the RabbitMQ server",
|
|
)
|
|
|
|
rabbitmq_vhost: str = Field(
|
|
default="/",
|
|
description="The vhost for the RabbitMQ server",
|
|
)
|
|
|
|
postmark_sender_email: str = Field(
|
|
default="invalid@invalid.com",
|
|
description="The email address to use for sending emails",
|
|
)
|
|
|
|
use_agent_image_generation_v2: bool = Field(
|
|
default=True,
|
|
description="Whether to use the new agent image generation service",
|
|
)
|
|
enable_agent_input_subtype_blocks: bool = Field(
|
|
default=True,
|
|
description="Whether to enable the agent input subtype blocks",
|
|
)
|
|
platform_alert_discord_channel: str = Field(
|
|
default="local-alerts",
|
|
description="The Discord channel for the platform",
|
|
)
|
|
|
|
@field_validator("platform_base_url", "frontend_base_url")
|
|
@classmethod
|
|
def validate_platform_base_url(cls, v: str, info: ValidationInfo) -> str:
|
|
if not v:
|
|
return v
|
|
if not v.startswith(("http://", "https://")):
|
|
raise ValueError(
|
|
f"{info.field_name} must be a full URL "
|
|
"including a http:// or https:// schema"
|
|
)
|
|
if v.endswith("/"):
|
|
return v.rstrip("/") # Remove trailing slash
|
|
return v
|
|
|
|
app_env: AppEnvironment = Field(
|
|
default=AppEnvironment.LOCAL,
|
|
description="The name of the app environment: local or dev or prod",
|
|
)
|
|
|
|
behave_as: BehaveAs = Field(
|
|
default=BehaveAs.LOCAL,
|
|
description="What environment to behave as: local or cloud",
|
|
)
|
|
|
|
execution_event_bus_name: str = Field(
|
|
default="execution_event",
|
|
description="Name of the event bus",
|
|
)
|
|
|
|
trust_endpoints_for_requests: List[str] = Field(
|
|
default_factory=list,
|
|
description="A whitelist of trusted internal endpoints for the backend to make requests to.",
|
|
)
|
|
|
|
backend_cors_allow_origins: List[str] = Field(default_factory=list)
|
|
|
|
@field_validator("backend_cors_allow_origins")
|
|
@classmethod
|
|
def validate_cors_allow_origins(cls, v: List[str]) -> List[str]:
|
|
out = []
|
|
port = None
|
|
has_localhost = False
|
|
has_127_0_0_1 = False
|
|
for url in v:
|
|
url = url.strip()
|
|
if url.startswith(("http://", "https://")):
|
|
if "localhost" in url:
|
|
port = url.split(":")[2]
|
|
has_localhost = True
|
|
if "127.0.0.1" in url:
|
|
port = url.split(":")[2]
|
|
has_127_0_0_1 = True
|
|
out.append(url)
|
|
else:
|
|
raise ValueError(f"Invalid URL: {url}")
|
|
|
|
if has_127_0_0_1 and not has_localhost:
|
|
out.append(f"http://localhost:{port}")
|
|
if has_localhost and not has_127_0_0_1:
|
|
out.append(f"http://127.0.0.1:{port}")
|
|
|
|
return out
|
|
|
|
@classmethod
|
|
def settings_customise_sources(
|
|
cls,
|
|
settings_cls: Type[BaseSettings],
|
|
init_settings: PydanticBaseSettingsSource,
|
|
env_settings: PydanticBaseSettingsSource,
|
|
dotenv_settings: PydanticBaseSettingsSource,
|
|
file_secret_settings: PydanticBaseSettingsSource,
|
|
) -> Tuple[PydanticBaseSettingsSource, ...]:
|
|
return (
|
|
env_settings,
|
|
file_secret_settings,
|
|
dotenv_settings,
|
|
JsonConfigSettingsSource(settings_cls),
|
|
init_settings,
|
|
)
|
|
|
|
|
|
class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
|
"""Secrets for the server."""
|
|
|
|
supabase_url: str = Field(default="", description="Supabase URL")
|
|
supabase_service_role_key: str = Field(
|
|
default="", description="Supabase service role key"
|
|
)
|
|
|
|
encryption_key: str = Field(default="", description="Encryption key")
|
|
|
|
rabbitmq_default_user: str = Field(default="", description="RabbitMQ default user")
|
|
rabbitmq_default_pass: str = Field(
|
|
default="", description="RabbitMQ default password"
|
|
)
|
|
|
|
postmark_server_api_token: str = Field(
|
|
default="", description="Postmark server API token used for sending emails"
|
|
)
|
|
|
|
postmark_webhook_token: str = Field(
|
|
default="",
|
|
description="The token to use for the Postmark webhook",
|
|
)
|
|
|
|
unsubscribe_secret_key: str = Field(
|
|
default="",
|
|
description="The secret key to use for the unsubscribe user by token",
|
|
)
|
|
|
|
# Cloudflare Turnstile credentials
|
|
turnstile_secret_key: str = Field(
|
|
default="",
|
|
description="Cloudflare Turnstile backend secret key",
|
|
)
|
|
turnstile_verify_url: str = Field(
|
|
default="https://challenges.cloudflare.com/turnstile/v0/siteverify",
|
|
description="Cloudflare Turnstile verify URL",
|
|
)
|
|
|
|
# OAuth server credentials for integrations
|
|
# --8<-- [start:OAuthServerCredentialsExample]
|
|
github_client_id: str = Field(default="", description="GitHub OAuth client ID")
|
|
github_client_secret: str = Field(
|
|
default="", description="GitHub OAuth client secret"
|
|
)
|
|
# --8<-- [end:OAuthServerCredentialsExample]
|
|
google_client_id: str = Field(default="", description="Google OAuth client ID")
|
|
google_client_secret: str = Field(
|
|
default="", description="Google OAuth client secret"
|
|
)
|
|
notion_client_id: str = Field(default="", description="Notion OAuth client ID")
|
|
notion_client_secret: str = Field(
|
|
default="", description="Notion OAuth client secret"
|
|
)
|
|
twitter_client_id: str = Field(default="", description="Twitter/X OAuth client ID")
|
|
twitter_client_secret: str = Field(
|
|
default="", description="Twitter/X OAuth client secret"
|
|
)
|
|
|
|
openai_api_key: str = Field(default="", description="OpenAI API key")
|
|
anthropic_api_key: str = Field(default="", description="Anthropic API key")
|
|
groq_api_key: str = Field(default="", description="Groq API key")
|
|
open_router_api_key: str = Field(default="", description="Open Router API Key")
|
|
llama_api_key: str = Field(default="", description="Llama API Key")
|
|
|
|
reddit_client_id: str = Field(default="", description="Reddit client ID")
|
|
reddit_client_secret: str = Field(default="", description="Reddit client secret")
|
|
|
|
openweathermap_api_key: str = Field(
|
|
default="", description="OpenWeatherMap API key"
|
|
)
|
|
|
|
medium_api_key: str = Field(default="", description="Medium API key")
|
|
medium_author_id: str = Field(default="", description="Medium author ID")
|
|
did_api_key: str = Field(default="", description="D-ID API Key")
|
|
revid_api_key: str = Field(default="", description="revid.ai API key")
|
|
discord_bot_token: str = Field(default="", description="Discord bot token")
|
|
|
|
smtp_server: str = Field(default="", description="SMTP server IP")
|
|
smtp_port: str = Field(default="", description="SMTP server port")
|
|
smtp_username: str = Field(default="", description="SMTP username")
|
|
smtp_password: str = Field(default="", description="SMTP password")
|
|
|
|
sentry_dsn: str = Field(default="", description="Sentry DSN")
|
|
|
|
google_maps_api_key: str = Field(default="", description="Google Maps API Key")
|
|
|
|
replicate_api_key: str = Field(default="", description="Replicate API Key")
|
|
unreal_speech_api_key: str = Field(default="", description="Unreal Speech API Key")
|
|
ideogram_api_key: str = Field(default="", description="Ideogram API Key")
|
|
jina_api_key: str = Field(default="", description="Jina API Key")
|
|
unreal_speech_api_key: str = Field(default="", description="Unreal Speech API Key")
|
|
|
|
fal_api_key: str = Field(default="", description="FAL API key")
|
|
exa_api_key: str = Field(default="", description="Exa API key")
|
|
e2b_api_key: str = Field(default="", description="E2B API key")
|
|
nvidia_api_key: str = Field(default="", description="Nvidia API key")
|
|
mem0_api_key: str = Field(default="", description="Mem0 API key")
|
|
|
|
linear_client_id: str = Field(default="", description="Linear client ID")
|
|
linear_client_secret: str = Field(default="", description="Linear client secret")
|
|
|
|
todoist_client_id: str = Field(default="", description="Todoist client ID")
|
|
todoist_client_secret: str = Field(default="", description="Todoist client secret")
|
|
|
|
stripe_api_key: str = Field(default="", description="Stripe API Key")
|
|
stripe_webhook_secret: str = Field(default="", description="Stripe Webhook Secret")
|
|
|
|
screenshotone_api_key: str = Field(default="", description="ScreenshotOne API Key")
|
|
|
|
apollo_api_key: str = Field(default="", description="Apollo API Key")
|
|
smartlead_api_key: str = Field(default="", description="SmartLead API Key")
|
|
zerobounce_api_key: str = Field(default="", description="ZeroBounce API Key")
|
|
|
|
# Add more secret fields as needed
|
|
|
|
model_config = SettingsConfigDict(
|
|
env_file=".env",
|
|
env_file_encoding="utf-8",
|
|
extra="allow",
|
|
)
|
|
|
|
|
|
class Settings(BaseModel):
|
|
config: Config = Config()
|
|
secrets: Secrets = Secrets()
|
|
|
|
def save(self) -> None:
|
|
# Save updated config to JSON file
|
|
if self.config.updated_fields:
|
|
config_to_save = self.config.get_updates()
|
|
config_path = os.path.join(get_data_path(), "config.json")
|
|
if os.path.exists(config_path):
|
|
with open(config_path, "r+") as f:
|
|
existing_config: Dict[str, Any] = json.load(f)
|
|
existing_config.update(config_to_save)
|
|
f.seek(0)
|
|
json.dump(existing_config, f, indent=2)
|
|
f.truncate()
|
|
else:
|
|
with open(config_path, "w") as f:
|
|
json.dump(config_to_save, f, indent=2)
|
|
self.config.clear_updates()
|