Compare commits

...

2 Commits

Author SHA1 Message Date
Brandon Rising
a828ea5de9 Allow optional base model lists to be passed in argparse 2024-03-11 19:00:46 -04:00
Brandon Rising
628639c565 Remove ability to pass remote_api_tokens via the CLI directly 2024-03-11 16:37:14 -04:00

View File

@@ -11,11 +11,12 @@ the command line.
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import json
import os import os
import sys import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, Type, get_args, get_origin, get_type_hints
from omegaconf import DictConfig, DictKeyType, ListConfig, OmegaConf from omegaconf import DictConfig, DictKeyType, ListConfig, OmegaConf
from pydantic import BaseModel from pydantic import BaseModel
@@ -23,6 +24,23 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
class ParseModelListAction(argparse.Action):
"""An argparse action that parses a JSON string into a list of Pydantic models."""
model_type: Type[BaseModel]
def __init__(self, model_type: Type[BaseModel], *args, **kwargs): # type: ignore
super(ParseModelListAction, self).__init__(*args, **kwargs) # type: ignore
self.model_type = model_type
def __call__(self, parser, namespace, values, option_string=None): # type: ignore
try:
items_data = json.loads(values) # type: ignore
items = [self.model_type(**item_data) for item_data in items_data]
setattr(namespace, self.dest, items)
except Exception as e:
parser.error(f"Could not parse models: {e}")
class InvokeAISettings(BaseSettings): class InvokeAISettings(BaseSettings):
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file.""" """Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
@@ -194,7 +212,28 @@ class InvokeAISettings(BaseSettings):
else: else:
argparse_group = command_parser argparse_group = command_parser
if get_origin(field_type) == Literal: def matches_optional_list_of_basemodel_subclasses(field_type):
args = get_args(field_type)
for arg in args:
list_origin = get_origin(arg)
if list_origin is list:
list_args = get_args(arg)
if len(list_args) == 1 and issubclass(list_args[0], BaseModel):
return list_args[0]
return None
if name == "remote_api_tokens":
pass
if bm_type:=matches_optional_list_of_basemodel_subclasses(field_type):
argparse_group.add_argument(
f"--{name}",
dest=name,
action=ParseModelListAction,
model_type=bm_type,
type=str,
default=default,
help=field.description,
)
elif get_origin(field_type) == Literal:
allowed_values = get_args(field.annotation) allowed_values = get_args(field.annotation)
allowed_types = set() allowed_types = set()
for val in allowed_values: for val in allowed_values: