diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index 07f01188ac..0905fcf345 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -101,7 +101,7 @@ class ModelProbe(object): upcast_attention = (base_type==BaseModelType.StableDiffusion2 \ and prediction_type==SchedulerPredictionType.VPrediction), format = format, - image_size = 1024 if (base_type==BaseModelType.StableDiffusionXL) else \ + image_size = 1024 if (base_type in {BaseModelType.StableDiffusionXL,BaseModelType.StableDiffusionXLRefiner}) else \ 768 if (base_type==BaseModelType.StableDiffusion2 \ and prediction_type==SchedulerPredictionType.VPrediction ) else \ 512 @@ -366,7 +366,9 @@ class PipelineFolderProbe(FolderProbeBase): return BaseModelType.StableDiffusion1 elif unet_conf['cross_attention_dim'] == 1024: return BaseModelType.StableDiffusion2 - elif unet_conf['cross_attention_dim'] in {1280,2048}: + elif unet_conf['cross_attention_dim'] == 1280: + return BaseModelType.StableDiffusionXLRefiner + elif unet_conf['cross_attention_dim'] == 2048: return BaseModelType.StableDiffusionXL else: raise ValueError(f'Unknown base model for {self.folder_path}') diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index aa94d640f4..6ad779aa90 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -3,7 +3,8 @@ from enum import Enum from pydantic import BaseModel from typing import Literal, get_origin from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException -from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model, StableDiffusionXLModel +from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model +from .sdxl import StableDiffusionXLModel from .vae import VaeModel from .lora import LoRAModel from .controlnet import ControlNetModel # TODO: @@ -32,6 +33,14 @@ MODEL_CLASSES = { ModelType.ControlNet: ControlNetModel, ModelType.TextualInversion: TextualInversionModel, }, + BaseModelType.StableDiffusionXLRefiner: { + ModelType.Main: StableDiffusionXLModel, + ModelType.Vae: VaeModel, + # will not work until support written + ModelType.Lora: LoRAModel, + ModelType.ControlNet: ControlNetModel, + ModelType.TextualInversion: TextualInversionModel, + }, #BaseModelType.Kandinsky2_1: { # ModelType.Main: Kandinsky2_1Model, # ModelType.MoVQ: MoVQModel, diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index 73cbb8eb3e..c4202052cc 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -22,6 +22,7 @@ class BaseModelType(str, Enum): StableDiffusion1 = "sd-1" StableDiffusion2 = "sd-2" StableDiffusionXL = "sdxl" + StableDiffusionXLRefiner = "sdxl-refiner" #Kandinsky2_1 = "kandinsky-2.1" class ModelType(str, Enum): diff --git a/invokeai/backend/model_management/models/sdxl.py b/invokeai/backend/model_management/models/sdxl.py new file mode 100644 index 0000000000..76cabcdc62 --- /dev/null +++ b/invokeai/backend/model_management/models/sdxl.py @@ -0,0 +1,114 @@ +import os +import json +from enum import Enum +from pydantic import Field +from typing import Literal, Optional +from .base import ( + ModelConfigBase, + BaseModelType, + ModelType, + ModelVariantType, + DiffusersModel, + read_checkpoint_meta, + classproperty, +) +from omegaconf import OmegaConf + +class StableDiffusionXLModelFormat(str, Enum): + Checkpoint = "checkpoint" + Diffusers = "diffusers" + +class StableDiffusionXLModel(DiffusersModel): + + # TODO: check that configs overwriten properly + class DiffusersConfig(ModelConfigBase): + model_format: Literal[StableDiffusionXLModelFormat.Diffusers] + vae: Optional[str] = Field(None) + variant: ModelVariantType + + class CheckpointConfig(ModelConfigBase): + model_format: Literal[StableDiffusionXLModelFormat.Checkpoint] + vae: Optional[str] = Field(None) + config: str + variant: ModelVariantType + + def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): + assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner} + assert model_type == ModelType.Main + super().__init__( + model_path=model_path, + base_model=BaseModelType.StableDiffusionXL, + model_type=ModelType.Main, + ) + + @classmethod + def probe_config(cls, path: str, **kwargs): + model_format = cls.detect_format(path) + ckpt_config_path = kwargs.get("config", None) + if model_format == StableDiffusionXLModelFormat.Checkpoint: + if ckpt_config_path: + ckpt_config = OmegaConf.load(ckpt_config_path) + ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] + + else: + checkpoint = read_checkpoint_meta(path) + checkpoint = checkpoint.get('state_dict', checkpoint) + in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] + + elif model_format == StableDiffusionXLModelFormat.Diffusers: + unet_config_path = os.path.join(path, "unet", "config.json") + if os.path.exists(unet_config_path): + with open(unet_config_path, "r") as f: + unet_config = json.loads(f.read()) + in_channels = unet_config['in_channels'] + + else: + raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)") + + else: + raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}") + + if in_channels == 9: + variant = ModelVariantType.Inpaint + elif in_channels == 5: + variant = ModelVariantType.Depth + elif in_channels == 4: + variant = ModelVariantType.Normal + else: + raise Exception("Unkown stable diffusion 2.* model format") + + if ckpt_config_path is None: + # TO DO: implement picking + pass + + return cls.create_config( + path=path, + model_format=model_format, + + config=ckpt_config_path, + variant=variant, + ) + + @classproperty + def save_to_config(cls) -> bool: + return True + + @classmethod + def detect_format(cls, model_path: str): + if os.path.isdir(model_path): + return StableDiffusionXLModelFormat.Diffusers + else: + return StableDiffusionXLModelFormat.Checkpoint + + @classmethod + def convert_if_required( + cls, + model_path: str, + output_path: str, + config: ModelConfigBase, + base_model: BaseModelType, + ) -> str: + if isinstance(config, cls.CheckpointConfig): + raise NotImplementedError('conversion of SDXL checkpoint models to diffusers format is not yet supported') + else: + return model_path diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index c0b43d6774..f29581ab1e 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -5,14 +5,11 @@ from pydantic import Field from pathlib import Path from typing import Literal, Optional, Union from .base import ( - ModelBase, ModelConfigBase, BaseModelType, ModelType, - SubModelType, ModelVariantType, DiffusersModel, - SchedulerPredictionType, SilenceWarnings, read_checkpoint_meta, classproperty, @@ -222,105 +219,6 @@ class StableDiffusion2Model(DiffusersModel): else: return model_path -class StableDiffusionXLModelFormat(str, Enum): - Checkpoint = "checkpoint" - Diffusers = "diffusers" - -class StableDiffusionXLModel(DiffusersModel): - - # TODO: check that configs overwriten properly - class DiffusersConfig(ModelConfigBase): - model_format: Literal[StableDiffusionXLModelFormat.Diffusers] - vae: Optional[str] = Field(None) - variant: ModelVariantType - - class CheckpointConfig(ModelConfigBase): - model_format: Literal[StableDiffusionXLModelFormat.Checkpoint] - vae: Optional[str] = Field(None) - config: str - variant: ModelVariantType - - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): - assert base_model == BaseModelType.StableDiffusionXL - assert model_type == ModelType.Main - super().__init__( - model_path=model_path, - base_model=BaseModelType.StableDiffusionXL, - model_type=ModelType.Main, - ) - - @classmethod - def probe_config(cls, path: str, **kwargs): - model_format = cls.detect_format(path) - ckpt_config_path = kwargs.get("config", None) - if model_format == StableDiffusionXLModelFormat.Checkpoint: - if ckpt_config_path: - ckpt_config = OmegaConf.load(ckpt_config_path) - ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] - - else: - checkpoint = read_checkpoint_meta(path) - checkpoint = checkpoint.get('state_dict', checkpoint) - in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - - elif model_format == StableDiffusionXLModelFormat.Diffusers: - unet_config_path = os.path.join(path, "unet", "config.json") - if os.path.exists(unet_config_path): - with open(unet_config_path, "r") as f: - unet_config = json.loads(f.read()) - in_channels = unet_config['in_channels'] - - else: - raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)") - - else: - raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}") - - if in_channels == 9: - variant = ModelVariantType.Inpaint - elif in_channels == 5: - variant = ModelVariantType.Depth - elif in_channels == 4: - variant = ModelVariantType.Normal - else: - raise Exception("Unkown stable diffusion 2.* model format") - - if ckpt_config_path is None: - ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusionXL, variant) - - return cls.create_config( - path=path, - model_format=model_format, - - config=ckpt_config_path, - variant=variant, - ) - - @classproperty - def save_to_config(cls) -> bool: - return True - - @classmethod - def detect_format(cls, model_path: str): - if os.path.isdir(model_path): - return StableDiffusionXLModelFormat.Diffusers - else: - return StableDiffusionXLModelFormat.Checkpoint - - @classmethod - def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if isinstance(config, cls.CheckpointConfig): - raise NotImplementedError('conversion of SDXL checkpoint models to diffusers format is not yet supported') - else: - return model_path - - def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType): ckpt_configs = { BaseModelType.StableDiffusion1: { @@ -355,7 +253,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType): # Note that convert_ckpt_to_diffuses does not currently support conversion of SDXL models def _convert_ckpt_and_cache( version: BaseModelType, - model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig, StableDiffusionXLModel.CheckpointConfig], + model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig], output_path: str, ) -> str: """