add config_management to context; config web settings

This commit is contained in:
Lincoln Stein
2023-04-18 15:33:36 -04:00
parent 196c21b7c9
commit 57d032d7e6
8 changed files with 118 additions and 81 deletions

View File

@@ -12,12 +12,13 @@ from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema
from ..backend import Args
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions, models
from .api.sockets import SocketIO
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.config_management import get_configuration
from .services.app_settings import InvokeAIWebConfig, InvokeAIAppConfig
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
@@ -33,15 +34,15 @@ app.add_middleware(
middleware_id=event_handler_id,
)
# Add CORS
# TODO: use configuration for this
origins = []
# Add CORS, using the web configuration stanza in `invokeai.yaml`
web_conf = get_configuration(InvokeAIWebConfig)
web_conf.parse_args()
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_origins=web_conf.allow_origins,
allow_credentials=web_conf.allow_credentials,
allow_methods=web_conf.allow_methods,
allow_headers=web_conf.allow_headers,
)
socket_io = SocketIO(app)
@@ -52,7 +53,7 @@ config = {}
# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event():
config = Args()
config = get_configuration(InvokeAIAppConfig)
config.parse_args()
ApiDependencies.initialize(

View File

@@ -17,7 +17,6 @@ from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
from .cli.completer import set_autocompleter
from .invocations import *
@@ -33,7 +32,8 @@ from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage
from .services.config_management import get_app_config
from .services.config_management import get_configuration
from .services.app_settings import InvokeAIAppConfig
class CliCommand(BaseModel):
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
@@ -188,7 +188,7 @@ def invoke_all(context: CliContext):
def invoke_cli():
config = get_app_config()
config = get_configuration(InvokeAIAppConfig)
config.parse_args()
model_manager = get_model_manager(config)
@@ -197,7 +197,7 @@ def invoke_cli():
# Currently nothing is done with the returned Completer
# object, but the object can be used to change autocompletion
# behavior on the fly, if desired.
completer = set_autocompleter(model_manager)
set_autocompleter(model_manager)
events = EventServiceBase()
@@ -220,6 +220,7 @@ def invoke_cli():
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config),
configuration=config,
)
system_graphs = create_system_graphs(services.graph_library)
@@ -259,6 +260,7 @@ def invoke_cli():
# Parse args to create invocation
args = vars(context.parser.parse_args(shlex.split(cmd.strip())))
print(f'DEBUG: cmd={cmd}, args={args}')
# Override defaults
for field_name, field_default in context.defaults.items():

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
from pydantic import BaseModel, Field
from pydantic import BaseModel, BaseSettings, Field
from ..services.invocation_services import InvocationServices
from ..services.config_management import InvokeAISettings
@@ -36,7 +36,7 @@ class BaseInvocationOutput(InvokeAISettings):
return tuple(subclasses)
class BaseInvocation(ABC, BaseModel):
class BaseInvocation(ABC, InvokeAISettings):
"""A node to process inputs and produce outputs.
May use dependency injection in __init__ to receive providers.
"""
@@ -101,8 +101,8 @@ class CustomisedSchemaExtra(TypedDict):
ui: UIConfig
class InvocationConfig(BaseModel.Config):
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
class InvocationConfig(BaseSettings.Config):
"""Customizes pydantic's BaseSettings.Config class for use by Invocations.
Provide `schema_extra` a `ui` dict to add hints for generated UIs.

View File

@@ -77,6 +77,8 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
print(f'DEBUG: steps = {self.steps}')
outputs = Txt2Img(model).generate(
prompt=self.prompt,
step_callback=partial(self.dispatch_progress, context),

View File

@@ -0,0 +1,68 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
from pydantic import Field
from pathlib import Path
from typing import Literal, List
from .config_management import InvokeAISettings, get_configuration
class InvokeAIWebConfig(InvokeAISettings):
    '''
    Web-specific settings.

    The `type` literal selects the `web` stanza of `invokeai.yaml`
    (see `get_configuration`, which loads that file into the shared
    `InvokeAISettings.initconf`); the remaining fields are handed to
    FastAPI's CORSMiddleware by the web entry point.
    '''
    #fmt: off
    # Discriminator: names the settings stanza this class reads.
    type : Literal["web"] = "web"
    # NOTE(review): annotated as bare `List`, so pydantic does not validate
    # element types — presumably lists of strings; confirm against callers.
    allow_origins : List = Field(default=[], description='Allowed CORS origins')
    allow_credentials : bool = Field(default=True, description='Allow CORS credentials')
    allow_methods : List = Field(default=["*"], description='Methods allowed for CORS')
    allow_headers : List = Field(default=["*"], description='Headers allowed for CORS')
    #fmt: on
class InvokeAIAppConfig(InvokeAISettings):
    '''
    Application-wide settings, not associated with any Invocation.

    The `type` literal selects the `app_settings` stanza of `invokeai.yaml`.
    Path-valued fields hold values relative to the runtime root directory;
    use the `*_path` properties below to obtain absolute, resolved paths.
    '''
    #fmt: off
    # Discriminator: names the settings stanza this class reads.
    type: Literal["app_settings"] = "app_settings"
    precision : Literal[tuple(['auto','float16','float32','autocast'])] = 'float16'
    conf : Path = Field(default='configs/models.yaml', description='Path to models definition file')
    outdir : Path = Field(default='outputs', description='Default folder for output images')
    # Tilde is expanded by the `root_dir` property, not here.
    root : Path = Field(default='~/invokeai', description='InvokeAI runtime root directory')
    embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI embeddings directory')
    autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.')
    gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.')
    embeddings : bool = Field(default=True, description='Load contents of embeddings directory')
    xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention")
    sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements")
    max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching")
    nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker")
    restore : bool = Field(default=True, description="Enable/disable face restoration code")
    esrgan : bool = Field(default=True, description="Enable/disable upscaling code")
    #fmt: on

    @property
    def root_dir(self)->Path:
        # Runtime root with `~` expanded to the user's home directory.
        return self.root.expanduser()

    def _resolve(self,partial_path:Path)->Path:
        # Resolve a (possibly relative) path against the runtime root.
        return (self.root_dir / partial_path).resolve()

    @property
    def output_path(self)->Path:
        # Absolute location of the image output folder.
        return self._resolve(self.outdir)

    @property
    def model_conf_path(self)->Path:
        # Absolute location of the models definition file.
        return self._resolve(self.conf)

    @property
    def embedding_path(self)->Path:
        # NOTE(review): returns None when `embedding_dir` is unset, despite
        # the `-> Path` annotation — callers must handle None.
        return self._resolve(self.embedding_dir) if self.embedding_dir else None

    @property
    def autoconvert_path(self)->Path:
        # NOTE(review): returns None when `autoconvert_dir` is unset, despite
        # the `-> Path` annotation — callers must handle None.
        return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None

    @property
    def gfpgan_model_path(self)->Path:
        # NOTE(review): returns None when `gfpgan_model_dir` is unset, despite
        # the `-> Path` annotation — callers must handle None.
        return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None

View File

@@ -44,6 +44,7 @@ print(conf.precision)
'''
import argparse
import os
import uuid
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig
from pathlib import Path
@@ -51,7 +52,6 @@ from pydantic import BaseSettings, Field
from typing import Any, ClassVar, List, Literal, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path('invokeai.yaml')
_invokeai_app_config = None
def _get_root_directory()->Path:
root = None
@@ -85,14 +85,18 @@ class InvokeAISettings(BaseSettings):
prog=__name__,
description='InvokeAI application',
)
env_prefix = self.Config.env_prefix
default_settings_stanza = get_args(get_type_hints(self)['type'])[0]
initconf = self.initconf.get(default_settings_stanza) if self.initconf and default_settings_stanza in self.initconf else None
fields = self.__fields__
for name, field in fields.items():
if name not in self._excluded():
env_name = env_prefix+name
if initconf and name in initconf:
field.default = initconf.get(name)
field.default = initconf.get(name)
if env_name in os.environ:
field.default = os.environ[env_name]
add_field_argument(parser, name, field)
return parser
@@ -104,7 +108,8 @@ class InvokeAISettings(BaseSettings):
env_file_encoding = 'utf-8'
arbitrary_types_allowed = True
env_prefix = 'INVOKEAI_'
class_sensitive = False
extra = 'allow'
class_sensitive = True
@classmethod
def customise_sources(
cls,
@@ -114,7 +119,7 @@ class InvokeAISettings(BaseSettings):
):
return (
init_settings,
InvokeAIAppConfig._omegaconf_settings_source,
InvokeAISettings._omegaconf_settings_source,
env_settings,
file_secret_settings,
)
@@ -128,64 +133,6 @@ class InvokeAISettings(BaseSettings):
else:
return {}
class InvokeAIAppConfig(InvokeAISettings):
'''
Application-wide settings, not associated with any Invocation
'''
type: Literal["app_settings"] = "app_settings"
precision : Literal[tuple(['auto','float16','float32','autocast'])] = 'float16'
conf : Path = Field(default='configs/models.yaml', description='Path to models definition file')
outdir : Path = Field(default='outputs', description='Default folder for output images')
root : Path = Field(default='~/invokeai', description='InvokeAI runtime root directory')
embedding_dir : Path = Field(default='embeddings', description='Path to InvokeAI embeddings directory')
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.')
gfpgan_model_dir : Path = Field(default="./models/gfpgan/GFPGANv1.4.pth", description='Path to GFPGAN models directory.')
embeddings : bool = Field(default=True, description='Load contents of embeddings directory')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention")
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements")
max_loaded_models : int = Field(default=2, gt=0, description="Maximum number of models to keep in memory for rapid switching")
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker")
restore : bool = Field(default=True, description="Enable/disable face restoration code")
esrgan : bool = Field(default=True, description="Enable/disable upscaling code")
@property
def root_dir(self)->Path:
return self.root.expanduser()
def _resolve(self,partial_path:Path)->Path:
return (self.root_dir / partial_path).resolve()
@property
def output_path(self)->Path:
return self._resolve(self.outdir)
@property
def model_conf_path(self)->Path:
return self._resolve(self.conf)
@property
def embedding_path(self)->Path:
return self._resolve(self.embedding_dir) if self.embedding_dir else None
@property
def autoconvert_path(self)->Path:
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
@property
def gfpgan_model_path(self)->Path:
return self._resolve(self.gfpgan_model_dir) if self.gfpgan_model_dir else None
def get_app_config(root: Path = _get_root_directory())->InvokeAIAppConfig:
    '''
    Return the process-wide InvokeAIAppConfig singleton, creating it on
    first call by loading `invokeai.yaml` from `root`.

    NOTE(review): the `root` default is evaluated once at module import
    time, so later environment changes are not picked up; and `root` is
    ignored on every call after the first because of the singleton cache.
    A load failure is reported but not fatal — defaults are used.
    '''
    global _invokeai_app_config
    if not _invokeai_app_config:
        conf_file = root / INIT_FILE
        try:
            InvokeAIAppConfig.conf = OmegaConf.load(conf_file)
        except OSError as e:
            print(f'** Initialization file could not be read. {str(e)}')
        _invokeai_app_config = InvokeAIAppConfig()
    return _invokeai_app_config
def add_field_argument(command_parser, name: str, field, default_override = None):
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
if get_origin(field.type_) == Literal:
@@ -214,3 +161,16 @@ def add_field_argument(command_parser, name: str, field, default_override = None
help=field.field_info.description,
)
def get_configuration(
    object_type: InvokeAISettings,
    root: Union[Path, None] = None,
)->InvokeAISettings:
    '''
    Instantiate and return a settings object of class `object_type`.

    Loads `invokeai.yaml` from `root` into the shared
    `InvokeAISettings.initconf` class variable so the returned object's
    argument parser and defaults can draw on the file's stanzas. A missing
    or unreadable init file is reported but not fatal — class defaults apply.

    :param object_type: the InvokeAISettings subclass to instantiate
    :param root: runtime root directory containing the init file; when
        omitted it is resolved at call time via `_get_root_directory()`
        (the previous default was evaluated once at import time, silently
        ignoring later environment changes)
    :returns: an instance of `object_type`
    '''
    if root is None:
        root = _get_root_directory()
    conf_file = root / INIT_FILE
    try:  # setting shared class variable; best effort
        InvokeAISettings.initconf = OmegaConf.load(conf_file)
    except OSError as e:
        print(f'** Initialization file could not be read. {str(e)}')
    # Settings classes that declare an `id` field get a fresh unique id.
    if 'id' in get_type_hints(object_type):
        return object_type(id=uuid.uuid4().hex)
    return object_type()

View File

@@ -7,6 +7,7 @@ from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from .config_management import InvokeAISettings
class InvocationServices:
"""Services that can be used by invocations"""
@@ -16,6 +17,7 @@ class InvocationServices:
images: ImageStorageBase
queue: InvocationQueueABC
model_manager: ModelManager
configuration: InvokeAISettings
restoration: RestorationServices
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
@@ -34,6 +36,7 @@ class InvocationServices:
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
configuration: InvokeAISettings,
):
self.model_manager = model_manager
self.events = events
@@ -44,3 +47,4 @@ class InvocationServices:
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration
self.configuration = configuration

View File

@@ -7,13 +7,13 @@ from omegaconf import OmegaConf
from pathlib import Path
import invokeai.version
from .config_management import InvokeAIAppConfig
from .config_management import InvokeAISettings
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals
# TODO: Replace with an abstract class base ModelManagerBase
def get_model_manager(config:InvokeAIAppConfig) -> ModelManager:
def get_model_manager(config:InvokeAISettings) -> ModelManager:
model_config = config.model_conf_path
if not model_config.exists():
report_model_error(