mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
refactor(mm): make config classes narrow
Simpler logic to identify, less complexity to add new model, fewer useless attrs that do not relate to the model arch, etc
This commit is contained in:
@@ -17,8 +17,8 @@ from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import (
|
||||
IPAdapter_InvokeAI_Config_Base,
|
||||
IPAdapterCheckpointConfig,
|
||||
IPAdapterInvokeAIConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
@@ -68,7 +68,7 @@ class FluxIPAdapterInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
|
||||
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig))
|
||||
|
||||
# Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy.
|
||||
image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
|
||||
|
||||
@@ -13,8 +13,8 @@ from invokeai.app.services.model_records.model_records_base import ModelRecordCh
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
IPAdapter_InvokeAI_Config_Base,
|
||||
IPAdapterCheckpointConfig,
|
||||
IPAdapterInvokeAIConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.starter_models import (
|
||||
StarterModel,
|
||||
@@ -123,9 +123,9 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
|
||||
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapterCheckpointConfig))
|
||||
|
||||
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
|
||||
if isinstance(ip_adapter_info, IPAdapter_InvokeAI_Config_Base):
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
else:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,9 +12,7 @@ from typing import Any, Dict, Generator, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
|
||||
|
||||
@@ -7,7 +7,7 @@ from diffusers import ControlNetModel
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
ControlNetCheckpointConfig,
|
||||
ControlNet_Checkpoint_Config_Base,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
@@ -46,7 +46,7 @@ class ControlNetLoader(GenericDiffusersLoader):
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if isinstance(config, ControlNetCheckpointConfig):
|
||||
if isinstance(config, ControlNet_Checkpoint_Config_Base):
|
||||
return ControlNetModel.from_single_file(
|
||||
config.path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
|
||||
@@ -37,26 +37,26 @@ from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
CheckpointConfigBase,
|
||||
CLIPEmbedDiffusersConfig,
|
||||
ControlNetCheckpointConfig,
|
||||
ControlNetDiffusersConfig,
|
||||
FLUX_Quantized_BnB_NF4_CheckpointConfig,
|
||||
FLUX_Quantized_GGUF_CheckpointConfig,
|
||||
FLUX_Unquantized_CheckpointConfig,
|
||||
FluxReduxConfig,
|
||||
IPAdapterCheckpointConfig,
|
||||
T5EncoderBnbQuantizedLlmInt8bConfig,
|
||||
T5EncoderConfig,
|
||||
VAECheckpointConfig,
|
||||
CLIPEmbed_Diffusers_Config_Base,
|
||||
ControlNet_Checkpoint_Config_Base,
|
||||
ControlNet_Diffusers_Config_Base,
|
||||
FLUXRedux_Checkpoint_Config,
|
||||
IPAdapter_Checkpoint_Config_Base,
|
||||
Main_FLUX_BnBNF4_Config,
|
||||
Main_FLUX_Checkpoint_Config,
|
||||
Main_FLUX_GGUF_Config,
|
||||
T5Encoder_BnBLLMint8_Config,
|
||||
T5Encoder_T5Encoder_Config,
|
||||
VAE_Checkpoint_Config_Base,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
FluxVariantType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import (
|
||||
@@ -86,7 +86,7 @@ class FluxVAELoader(ModelLoader):
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, VAECheckpointConfig):
|
||||
if not isinstance(config, VAE_Checkpoint_Config_Base):
|
||||
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
|
||||
model_path = Path(config.path)
|
||||
|
||||
@@ -116,7 +116,7 @@ class CLIPDiffusersLoader(ModelLoader):
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, CLIPEmbedDiffusersConfig):
|
||||
if not isinstance(config, CLIPEmbed_Diffusers_Config_Base):
|
||||
raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
@@ -139,7 +139,7 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig):
|
||||
if not isinstance(config, T5Encoder_BnBLLMint8_Config):
|
||||
raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.")
|
||||
if not bnb_available:
|
||||
raise ImportError(
|
||||
@@ -186,7 +186,7 @@ class T5EncoderCheckpointModel(ModelLoader):
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, T5EncoderConfig):
|
||||
if not isinstance(config, T5Encoder_T5Encoder_Config):
|
||||
raise ValueError("Only T5EncoderConfig models are currently supported here.")
|
||||
|
||||
match submodel_type:
|
||||
@@ -226,7 +226,7 @@ class FluxCheckpointModel(ModelLoader):
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
) -> AnyModel:
|
||||
assert isinstance(config, FLUX_Unquantized_CheckpointConfig)
|
||||
assert isinstance(config, Main_FLUX_Checkpoint_Config)
|
||||
model_path = Path(config.path)
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
@@ -268,7 +268,7 @@ class FluxGGUFCheckpointModel(ModelLoader):
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
) -> AnyModel:
|
||||
assert isinstance(config, FLUX_Quantized_GGUF_CheckpointConfig)
|
||||
assert isinstance(config, Main_FLUX_GGUF_Config)
|
||||
model_path = Path(config.path)
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
@@ -314,7 +314,7 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
) -> AnyModel:
|
||||
assert isinstance(config, FLUX_Quantized_BnB_NF4_CheckpointConfig)
|
||||
assert isinstance(config, Main_FLUX_BnBNF4_Config)
|
||||
if not bnb_available:
|
||||
raise ImportError(
|
||||
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
|
||||
@@ -342,9 +342,9 @@ class FluxControlnetModel(ModelLoader):
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if isinstance(config, ControlNetCheckpointConfig):
|
||||
if isinstance(config, ControlNet_Checkpoint_Config_Base):
|
||||
model_path = Path(config.path)
|
||||
elif isinstance(config, ControlNetDiffusersConfig):
|
||||
elif isinstance(config, ControlNet_Diffusers_Config_Base):
|
||||
# If this is a diffusers directory, we simply ignore the config file and load from the weight file.
|
||||
model_path = Path(config.path) / "diffusion_pytorch_model.safetensors"
|
||||
else:
|
||||
@@ -363,7 +363,7 @@ class FluxControlnetModel(ModelLoader):
|
||||
def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
|
||||
with accelerate.init_empty_weights():
|
||||
# HACK(ryand): Is it safe to assume dev here?
|
||||
model = XLabsControlNetFlux(get_flux_transformers_params(ModelVariantType.FluxDev))
|
||||
model = XLabsControlNetFlux(get_flux_transformers_params(FluxVariantType.Dev))
|
||||
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
@@ -389,7 +389,7 @@ class FluxIpAdapterModel(ModelLoader):
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, IPAdapterCheckpointConfig):
|
||||
if not isinstance(config, IPAdapter_Checkpoint_Config_Base):
|
||||
raise ValueError(f"Unexpected model config type: {type(config)}.")
|
||||
|
||||
sd = load_file(Path(config.path))
|
||||
@@ -412,7 +412,7 @@ class FluxReduxModelLoader(ModelLoader):
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not isinstance(config, FluxReduxConfig):
|
||||
if not isinstance(config, FLUXRedux_Checkpoint_Config):
|
||||
raise ValueError(f"Unexpected model config type: {type(config)}.")
|
||||
|
||||
sd = load_file(Path(config.path))
|
||||
|
||||
@@ -15,8 +15,14 @@ from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
CheckpointConfigBase,
|
||||
DiffusersConfigBase,
|
||||
SD_1_2_XL_XLRefiner_CheckpointConfig,
|
||||
SD_1_2_XL_XLRefiner_DiffusersConfig,
|
||||
Main_SD1_Checkpoint_Config,
|
||||
Main_SD1_Diffusers_Config,
|
||||
Main_SD2_Checkpoint_Config,
|
||||
Main_SD2_Diffusers_Config,
|
||||
Main_SDXL_Checkpoint_Config,
|
||||
Main_SDXL_Diffusers_Config,
|
||||
Main_SDXLRefiner_Checkpoint_Config,
|
||||
Main_SDXLRefiner_Diffusers_Config,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
@@ -108,7 +114,19 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
||||
},
|
||||
}
|
||||
assert isinstance(config, (SD_1_2_XL_XLRefiner_DiffusersConfig, SD_1_2_XL_XLRefiner_CheckpointConfig))
|
||||
assert isinstance(
|
||||
config,
|
||||
(
|
||||
Main_SD1_Diffusers_Config,
|
||||
Main_SD2_Diffusers_Config,
|
||||
Main_SDXL_Diffusers_Config,
|
||||
Main_SDXLRefiner_Diffusers_Config,
|
||||
Main_SD1_Checkpoint_Config,
|
||||
Main_SD2_Checkpoint_Config,
|
||||
Main_SDXL_Checkpoint_Config,
|
||||
Main_SDXLRefiner_Checkpoint_Config,
|
||||
),
|
||||
)
|
||||
try:
|
||||
load_class = load_classes[config.base][config.variant]
|
||||
except KeyError as e:
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, VAECheckpointConfig
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, VAE_Checkpoint_Config_Base
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
@@ -27,7 +27,7 @@ class VAELoader(GenericDiffusersLoader):
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if isinstance(config, VAECheckpointConfig):
|
||||
if isinstance(config, VAE_Checkpoint_Config_Base):
|
||||
return AutoencoderKL.from_single_file(
|
||||
config.path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
|
||||
@@ -21,7 +21,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ControlAdapterDefaultSettings,
|
||||
MainDiffusersConfig,
|
||||
MainModelDefaultSettings,
|
||||
TextualInversionFileConfig,
|
||||
TI_File_Config,
|
||||
VAEDiffusersConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import ModelSourceType
|
||||
@@ -40,8 +40,8 @@ def store(
|
||||
return ModelRecordServiceSQL(db, logger)
|
||||
|
||||
|
||||
def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig:
|
||||
config = TextualInversionFileConfig(
|
||||
def example_ti_config(key: Optional[str] = None) -> TI_File_Config:
|
||||
config = TI_File_Config(
|
||||
source="test/source/",
|
||||
source_type=ModelSourceType.Path,
|
||||
path="/tmp/pokemon.bin",
|
||||
@@ -61,7 +61,7 @@ def test_type(store: ModelRecordServiceBase):
|
||||
config = example_ti_config("key1")
|
||||
store.add_model(config)
|
||||
config1 = store.get_model("key1")
|
||||
assert isinstance(config1, TextualInversionFileConfig)
|
||||
assert isinstance(config1, TI_File_Config)
|
||||
|
||||
|
||||
def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase):
|
||||
|
||||
Reference in New Issue
Block a user