mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-10 07:38:04 -05:00
feat(backend): add the capability to disable llm models in the cloud env (#8285)
* feat(backend): logic to disable enums based on python logic
* feat(backend): add behave as setting and clarify its purpose and APP_ENV
APP_ENV is used not for cloud-vs-local but for the application environment (local/dev/prod), so we need BehaveAs as well
* fix(backend): various uses of AppEnvironment without the Enum or incorrectly
AppEnv in the logging library will never be cloud due to the restrictions applied when loading settings via pydantic-settings. This commit fixes that error; however, the code path for logging may now be incorrect.
* feat(backend): use a metaclass to disable ollama in the cloud environment
* fix: formatting
* fix(backend): typing improvements
* fix(backend): more linting 😭
This commit is contained in:
@@ -12,7 +12,10 @@ REDIS_PORT=6379
|
||||
REDIS_PASSWORD=password
|
||||
|
||||
ENABLE_CREDIT=false
|
||||
APP_ENV="local"
|
||||
# What environment things should be logged under: local dev or prod
|
||||
APP_ENV=local
|
||||
# What environment to behave as: "local" or "cloud"
|
||||
BEHAVE_AS=local
|
||||
PYRO_HOST=localhost
|
||||
SENTRY_DSN=
|
||||
|
||||
@@ -98,5 +101,3 @@ ENABLE_CLOUD_LOGGING=false
|
||||
ENABLE_FILE_LOGGING=false
|
||||
# Use to manually set the log directory
|
||||
# LOG_DIR=./logs
|
||||
|
||||
APP_ENV=local
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import ast
|
||||
import logging
|
||||
from enum import Enum
|
||||
from enum import Enum, EnumMeta
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, List, NamedTuple
|
||||
from types import MappingProxyType
|
||||
from typing import TYPE_CHECKING, Any, List, NamedTuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enum import _EnumMemberT
|
||||
|
||||
import anthropic
|
||||
import ollama
|
||||
@@ -12,6 +16,7 @@ from groq import Groq
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||
from backend.util import json
|
||||
from backend.util.settings import BehaveAs, Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,7 +34,26 @@ class ModelMetadata(NamedTuple):
|
||||
cost_factor: int
|
||||
|
||||
|
||||
class LlmModel(str, Enum):
|
||||
class LlmModelMeta(EnumMeta):
    """EnumMeta that hides members of disabled LLM providers.

    When the app behaves as a cloud deployment, models from providers in
    ``removed_providers`` (currently only "ollama") are filtered out of
    ``__members__``; in local mode all members are exposed unchanged.

    NOTE(review): this filters only ``__members__`` — direct lookups such as
    ``LlmModel["OLLAMA_..."]`` or iteration via ``_member_map_`` are not
    affected; confirm callers rely on ``__members__`` only.
    """

    @property
    def __members__(
        self: type["_EnumMemberT"],
    ) -> MappingProxyType[str, "_EnumMemberT"]:
        # Fetch the real member map once instead of once per branch.
        members = super().__members__
        if Settings().config.behave_as == BehaveAs.LOCAL:
            return members
        # Cloud mode: drop models served by disabled providers. Use the
        # member object itself rather than re-looking-up LlmModel[name],
        # which was redundant and hardcoded LlmModel into the metaclass.
        removed_providers = ("ollama",)
        return MappingProxyType(
            {
                name: member
                for name, member in members.items()
                if member.provider not in removed_providers
            }
        )
|
||||
|
||||
|
||||
class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
# OpenAI models
|
||||
O1_PREVIEW = "o1-preview"
|
||||
O1_MINI = "o1-mini"
|
||||
@@ -58,6 +82,18 @@ class LlmModel(str, Enum):
|
||||
def metadata(self) -> ModelMetadata:
|
||||
return MODEL_METADATA[self]
|
||||
|
||||
@property
def provider(self) -> str:
    """Name of the provider serving this model (taken from MODEL_METADATA)."""
    meta = self.metadata
    return meta.provider
|
||||
|
||||
@property
def context_window(self) -> int:
    """Maximum context window (in tokens) for this model, per MODEL_METADATA."""
    meta = self.metadata
    return meta.context_window
|
||||
|
||||
@property
def cost_factor(self) -> int:
    """Relative cost multiplier for this model, per MODEL_METADATA."""
    meta = self.metadata
    return meta.cost_factor
|
||||
|
||||
|
||||
MODEL_METADATA = {
|
||||
LlmModel.O1_PREVIEW: ModelMetadata("openai", 32000, cost_factor=60),
|
||||
|
||||
@@ -23,7 +23,7 @@ from backend.data.user import get_or_create_user
|
||||
from backend.executor import ExecutionManager, ExecutionScheduler
|
||||
from backend.server.model import CreateGraph, SetGraphActiveVersion
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
from backend.util.settings import Config, Settings
|
||||
from backend.util.settings import AppEnvironment, Config, Settings
|
||||
|
||||
from .utils import get_user_id
|
||||
|
||||
@@ -52,7 +52,7 @@ class AgentServer(AppService):
|
||||
await db.disconnect()
|
||||
|
||||
def run_service(self):
|
||||
docs_url = "/docs" if settings.config.app_env == "local" else None
|
||||
docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
|
||||
app = FastAPI(
|
||||
title="AutoGPT Agent Server",
|
||||
description=(
|
||||
|
||||
@@ -12,7 +12,7 @@ from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import ExecutionSubscription, Methods, WsMessage
|
||||
from backend.util.service import AppProcess
|
||||
from backend.util.settings import Config, Settings
|
||||
from backend.util.settings import AppEnvironment, Config, Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -28,7 +28,7 @@ async def lifespan(app: FastAPI):
|
||||
event_queue.close()
|
||||
|
||||
|
||||
docs_url = "/docs" if settings.config.app_env == "local" else None
|
||||
docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
event_queue = RedisEventQueue()
|
||||
_connection_manager = None
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
from backend.util.settings import AppEnvironment, BehaveAs, Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def configure_logging():
|
||||
@@ -6,7 +8,10 @@ def configure_logging():
|
||||
|
||||
import autogpt_libs.logging.config
|
||||
|
||||
if os.getenv("APP_ENV") != "cloud":
|
||||
if (
|
||||
settings.config.behave_as == BehaveAs.LOCAL
|
||||
or settings.config.app_env == AppEnvironment.LOCAL
|
||||
):
|
||||
autogpt_libs.logging.config.configure_logging(force_cloud_logging=False)
|
||||
else:
|
||||
autogpt_libs.logging.config.configure_logging(force_cloud_logging=True)
|
||||
|
||||
@@ -22,6 +22,11 @@ class AppEnvironment(str, Enum):
|
||||
PRODUCTION = "prod"
|
||||
|
||||
|
||||
class BehaveAs(str, Enum):
    """How the application should behave: as a local install or a cloud deployment.

    Distinct from AppEnvironment (local/dev/prod), which names the deployment
    environment for logging/config purposes; BehaveAs gates cloud-only behavior
    such as disabling the ollama provider.
    """

    LOCAL = "local"
    CLOUD = "cloud"
|
||||
|
||||
|
||||
class UpdateTrackingModel(BaseModel, Generic[T]):
|
||||
_updated_fields: Set[str] = PrivateAttr(default_factory=set)
|
||||
|
||||
@@ -130,7 +135,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
|
||||
app_env: AppEnvironment = Field(
|
||||
default=AppEnvironment.LOCAL,
|
||||
description="The name of the app environment.",
|
||||
description="The name of the app environment: local or dev or prod",
|
||||
)
|
||||
|
||||
behave_as: BehaveAs = Field(
|
||||
default=BehaveAs.LOCAL,
|
||||
description="What environment to behave as: local or cloud",
|
||||
)
|
||||
|
||||
backend_cors_allow_origins: List[str] = Field(default_factory=list)
|
||||
|
||||
Reference in New Issue
Block a user