mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
add config_management to context; config web settings
This commit is contained in:
@@ -12,12 +12,13 @@ from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.schema import schema
|
||||
|
||||
from ..backend import Args
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import images, sessions, models
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.config_management import get_configuration
|
||||
from .services.app_settings import InvokeAIWebConfig, InvokeAIAppConfig
|
||||
|
||||
# Create the app
|
||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||
@@ -33,15 +34,15 @@ app.add_middleware(
|
||||
middleware_id=event_handler_id,
|
||||
)
|
||||
|
||||
# Add CORS
|
||||
# TODO: use configuration for this
|
||||
origins = []
|
||||
# Add CORS, using the web configuration stanza in `invokeai.yaml`
|
||||
web_conf = get_configuration(InvokeAIWebConfig)
|
||||
web_conf.parse_args()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
allow_origins=web_conf.allow_origins,
|
||||
allow_credentials=web_conf.allow_credentials,
|
||||
allow_methods=web_conf.allow_methods,
|
||||
allow_headers=web_conf.allow_headers,
|
||||
)
|
||||
|
||||
socket_io = SocketIO(app)
|
||||
@@ -52,7 +53,7 @@ config = {}
|
||||
# Add startup event to load dependencies
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
config = Args()
|
||||
config = get_configuration(InvokeAIAppConfig)
|
||||
config.parse_args()
|
||||
|
||||
ApiDependencies.initialize(
|
||||
|
||||
@@ -17,7 +17,6 @@ from .services.default_graphs import create_system_graphs
|
||||
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ..backend import Args
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations import *
|
||||
@@ -33,7 +32,8 @@ from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
from .services.processor import DefaultInvocationProcessor
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
from .services.config_management import get_app_config
|
||||
from .services.config_management import get_configuration
|
||||
from .services.app_settings import InvokeAIAppConfig
|
||||
|
||||
class CliCommand(BaseModel):
|
||||
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||
@@ -188,7 +188,7 @@ def invoke_all(context: CliContext):
|
||||
|
||||
|
||||
def invoke_cli():
|
||||
config = get_app_config()
|
||||
config = get_configuration(InvokeAIAppConfig)
|
||||
config.parse_args()
|
||||
|
||||
model_manager = get_model_manager(config)
|
||||
@@ -197,7 +197,7 @@ def invoke_cli():
|
||||
# Currently nothing is done with the returned Completer
|
||||
# object, but the object can be used to change autocompletion
|
||||
# behavior on the fly, if desired.
|
||||
completer = set_autocompleter(model_manager)
|
||||
set_autocompleter(model_manager)
|
||||
|
||||
events = EventServiceBase()
|
||||
|
||||
@@ -220,6 +220,7 @@ def invoke_cli():
|
||||
),
|
||||
processor=DefaultInvocationProcessor(),
|
||||
restoration=RestorationServices(config),
|
||||
configuration=config,
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
@@ -259,6 +260,7 @@ def invoke_cli():
|
||||
|
||||
# Parse args to create invocation
|
||||
args = vars(context.parser.parse_args(shlex.split(cmd.strip())))
|
||||
print(f'DEBUG: cmd={cmd}, args={args}')
|
||||
|
||||
# Override defaults
|
||||
for field_name, field_default in context.defaults.items():
|
||||
|
||||
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, BaseSettings, Field
|
||||
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.config_management import InvokeAISettings
|
||||
@@ -36,7 +36,7 @@ class BaseInvocationOutput(InvokeAISettings):
|
||||
return tuple(subclasses)
|
||||
|
||||
|
||||
class BaseInvocation(ABC, BaseModel):
|
||||
class BaseInvocation(ABC, InvokeAISettings):
|
||||
"""A node to process inputs and produce outputs.
|
||||
May use dependency injection in __init__ to receive providers.
|
||||
"""
|
||||
@@ -101,8 +101,8 @@ class CustomisedSchemaExtra(TypedDict):
|
||||
ui: UIConfig
|
||||
|
||||
|
||||
class InvocationConfig(BaseModel.Config):
|
||||
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
|
||||
class InvocationConfig(BaseSettings.Config):
|
||||
"""Customizes pydantic's BaseSettings.Config class for use by Invocations.
|
||||
|
||||
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
|
||||
|
||||
|
||||
@@ -77,6 +77,8 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
# Handle invalid model parameter
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
print(f'DEBUG: steps = {self.steps}')
|
||||
|
||||
outputs = Txt2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
step_callback=partial(self.dispatch_progress, context),
|
||||
|
||||
68
invokeai/app/services/app_settings.py
Normal file
68
invokeai/app/services/app_settings.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
from typing import Literal, List
|
||||
from .config_management import InvokeAISettings, get_configuration
|
||||
|
||||
class InvokeAIWebConfig(InvokeAISettings):
|
||||
'''
|
||||
Web-specific settings
|
||||
'''
|
||||
#fmt: off
|
||||
type : Literal["web"] = "web"
|
||||
allow_origins : List = Field(default=[], description='Allowed CORS origins')
|
||||
allow_credentials : bool = Field(default=True, description='Allow CORS credentials')
|
||||
allow_methods : List = Field(default=["*"], description='Methods allowed for CORS')
|
||||
allow_headers : List = Field(default=["*"], description='Headers allowed for CORS')
|
||||
#fmt: on
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
'''
|
||||
Application-wide settings, not associated with any Invocation
|
||||
'''
|
||||
#fmt: off
|
||||
type: Literal["app_settings"] = "app_settings"
|
||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = 'float16'
|
||||
conf : Path = Field(default='configs/models.yaml', description='Path to models definition file')
|
||||
outdir : Path = Field(default='outputs', description='Default folder for output images')
|
||||
root : Path = Field(default='~/invokeai', description='InvokeAI runtime root directory')
|
||||
embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI embeddings directory')
|
||||
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.')
|
||||
gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.')
|
||||
embeddings : bool = Field(default=True, description='Load contents of embeddings directory')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention")
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements")
|
||||
max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching")
|
||||
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker")
|
||||
restore : bool = Field(default=True, description="Enable/disable face restoration code")
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code")
|
||||
#fmt: on
|
||||
|
||||
@property
|
||||
def root_dir(self)->Path:
|
||||
return self.root.expanduser()
|
||||
|
||||
def _resolve(self,partial_path:Path)->Path:
|
||||
return (self.root_dir / partial_path).resolve()
|
||||
|
||||
@property
|
||||
def output_path(self)->Path:
|
||||
return self._resolve(self.outdir)
|
||||
|
||||
@property
|
||||
def model_conf_path(self)->Path:
|
||||
return self._resolve(self.conf)
|
||||
|
||||
@property
|
||||
def embedding_path(self)->Path:
|
||||
return self._resolve(self.embedding_dir) if self.embedding_dir else None
|
||||
|
||||
@property
|
||||
def autoconvert_path(self)->Path:
|
||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||
|
||||
@property
|
||||
def gfpgan_model_path(self)->Path:
|
||||
return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None
|
||||
|
||||
@@ -44,6 +44,7 @@ print(conf.precision)
|
||||
'''
|
||||
import argparse
|
||||
import os
|
||||
import uuid
|
||||
from argparse import ArgumentParser
|
||||
from omegaconf import OmegaConf, DictConfig
|
||||
from pathlib import Path
|
||||
@@ -51,7 +52,6 @@ from pydantic import BaseSettings, Field
|
||||
from typing import Any, ClassVar, List, Literal, Union, get_origin, get_type_hints, get_args
|
||||
|
||||
INIT_FILE = Path('invokeai.yaml')
|
||||
_invokeai_app_config = None
|
||||
|
||||
def _get_root_directory()->Path:
|
||||
root = None
|
||||
@@ -85,14 +85,18 @@ class InvokeAISettings(BaseSettings):
|
||||
prog=__name__,
|
||||
description='InvokeAI application',
|
||||
)
|
||||
env_prefix = self.Config.env_prefix
|
||||
default_settings_stanza = get_args(get_type_hints(self)['type'])[0]
|
||||
initconf = self.initconf.get(default_settings_stanza) if self.initconf and default_settings_stanza in self.initconf else None
|
||||
|
||||
fields = self.__fields__
|
||||
for name, field in fields.items():
|
||||
if name not in self._excluded():
|
||||
env_name = env_prefix+name
|
||||
if initconf and name in initconf:
|
||||
field.default = initconf.get(name)
|
||||
field.default = initconf.get(name)
|
||||
if env_name in os.environ:
|
||||
field.default = os.environ[env_name]
|
||||
add_field_argument(parser, name, field)
|
||||
return parser
|
||||
|
||||
@@ -104,7 +108,8 @@ class InvokeAISettings(BaseSettings):
|
||||
env_file_encoding = 'utf-8'
|
||||
arbitrary_types_allowed = True
|
||||
env_prefix = 'INVOKEAI_'
|
||||
class_sensitive = False
|
||||
extra = 'allow'
|
||||
class_sensitive = True
|
||||
@classmethod
|
||||
def customise_sources(
|
||||
cls,
|
||||
@@ -114,7 +119,7 @@ class InvokeAISettings(BaseSettings):
|
||||
):
|
||||
return (
|
||||
init_settings,
|
||||
InvokeAIAppConfig._omegaconf_settings_source,
|
||||
InvokeAISettings._omegaconf_settings_source,
|
||||
env_settings,
|
||||
file_secret_settings,
|
||||
)
|
||||
@@ -128,64 +133,6 @@ class InvokeAISettings(BaseSettings):
|
||||
else:
|
||||
return {}
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
'''
|
||||
Application-wide settings, not associated with any Invocation
|
||||
'''
|
||||
type: Literal["app_settings"] = "app_settings"
|
||||
precision : Literal[tuple(['auto','float16','float32','autocast'])] = 'float16'
|
||||
conf : Path = Field(default='configs/models.yaml', description='Path to models definition file')
|
||||
outdir : Path = Field(default='outputs', description='Default folder for output images')
|
||||
root : Path = Field(default='~/invokeai', description='InvokeAI runtime root directory')
|
||||
embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI embeddings directory')
|
||||
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.')
|
||||
gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.')
|
||||
embeddings : bool = Field(default=True, description='Load contents of embeddings directory')
|
||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention")
|
||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements")
|
||||
max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching")
|
||||
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker")
|
||||
restore : bool = Field(default=True, description="Enable/disable face restoration code")
|
||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code")
|
||||
|
||||
@property
|
||||
def root_dir(self)->Path:
|
||||
return self.root.expanduser()
|
||||
|
||||
def _resolve(self,partial_path:Path)->Path:
|
||||
return (self.root_dir / partial_path).resolve()
|
||||
|
||||
@property
|
||||
def output_path(self)->Path:
|
||||
return self._resolve(self.outdir)
|
||||
|
||||
@property
|
||||
def model_conf_path(self)->Path:
|
||||
return self._resolve(self.conf)
|
||||
|
||||
@property
|
||||
def embedding_path(self)->Path:
|
||||
return self._resolve(self.embedding_dir) if self.embedding_dir else None
|
||||
|
||||
@property
|
||||
def autoconvert_path(self)->Path:
|
||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||
|
||||
@property
|
||||
def gfpgan_model_path(self)->Path:
|
||||
return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None
|
||||
|
||||
def get_app_config(root: Path = _get_root_directory())->InvokeAIAppConfig:
|
||||
global _invokeai_app_config
|
||||
if not _invokeai_app_config:
|
||||
conf_file = root / INIT_FILE
|
||||
try:
|
||||
InvokeAIAppConfig.conf = OmegaConf.load(conf_file)
|
||||
except OSError as e:
|
||||
print(f'** Initialization file could not be read. {str(e)}')
|
||||
_invokeai_app_config = InvokeAIAppConfig()
|
||||
return _invokeai_app_config
|
||||
|
||||
def add_field_argument(command_parser, name: str, field, default_override = None):
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
if get_origin(field.type_) == Literal:
|
||||
@@ -214,3 +161,16 @@ def add_field_argument(command_parser, name: str, field, default_override = None
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
def get_configuration(
|
||||
object_type: InvokeAISettings,
|
||||
root: Path = _get_root_directory(),
|
||||
)->InvokeAISettings:
|
||||
conf_file = root / INIT_FILE
|
||||
try: # setting shared class variable
|
||||
InvokeAISettings.initconf = OmegaConf.load(conf_file)
|
||||
except OSError as e:
|
||||
print(f'** Initialization file could not be read. {str(e)}')
|
||||
return object_type(id=uuid.uuid4().hex) \
|
||||
if 'id' in get_type_hints(object_type) \
|
||||
else object_type()
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from .image_storage import ImageStorageBase
|
||||
from .restoration_services import RestorationServices
|
||||
from .invocation_queue import InvocationQueueABC
|
||||
from .item_storage import ItemStorageABC
|
||||
from .config_management import InvokeAISettings
|
||||
|
||||
class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
@@ -16,6 +17,7 @@ class InvocationServices:
|
||||
images: ImageStorageBase
|
||||
queue: InvocationQueueABC
|
||||
model_manager: ModelManager
|
||||
configuration: InvokeAISettings
|
||||
restoration: RestorationServices
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
@@ -34,6 +36,7 @@ class InvocationServices:
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: RestorationServices,
|
||||
configuration: InvokeAISettings,
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.events = events
|
||||
@@ -44,3 +47,4 @@ class InvocationServices:
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
self.configuration = configuration
|
||||
|
||||
@@ -7,13 +7,13 @@ from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
|
||||
import invokeai.version
|
||||
from .config_management import InvokeAIAppConfig
|
||||
from .config_management import InvokeAISettings
|
||||
from ...backend import ModelManager
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from ...backend import Globals
|
||||
|
||||
# TODO: Replace with an abstract class base ModelManagerBase
|
||||
def get_model_manager(config:InvokeAIAppConfig) -> ModelManager:
|
||||
def get_model_manager(config:InvokeAISettings) -> ModelManager:
|
||||
model_config = config.model_conf_path
|
||||
if not model_config.exists():
|
||||
report_model_error(
|
||||
|
||||
Reference in New Issue
Block a user