mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): wip port of main models to new api
This commit is contained in:
@@ -1094,7 +1094,6 @@ MainDiffusers_SupportedBases: TypeAlias = Literal[
|
||||
BaseModelType.StableDiffusion3,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
BaseModelType.StableDiffusionXLRefiner,
|
||||
BaseModelType.Flux,
|
||||
BaseModelType.CogView4,
|
||||
]
|
||||
|
||||
@@ -1104,6 +1103,157 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
||||
|
||||
base: MainDiffusers_SupportedBases = Field()
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
|
||||
_validate_is_dir(cls, mod)
|
||||
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
_validate_class_name(
|
||||
cls,
|
||||
mod.path / "config.json",
|
||||
{
|
||||
"StableDiffusionPipeline",
|
||||
"StableDiffusionInpaintPipeline",
|
||||
"StableDiffusionXLPipeline",
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusion3Pipeline",
|
||||
"LatentConsistencyModelPipeline",
|
||||
"SD3Transformer2DModel",
|
||||
"CogView4Pipeline",
|
||||
},
|
||||
)
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
|
||||
return cls(**fields, base=base)
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases:
|
||||
# Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
|
||||
unet_config_path = mod.path / "unet" / "config.json"
|
||||
if unet_config_path.exists():
|
||||
with open(unet_config_path) as file:
|
||||
unet_conf = json.load(file)
|
||||
cross_attention_dim = unet_conf.get("cross_attention_dim")
|
||||
match cross_attention_dim:
|
||||
case 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
case 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
case 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
case 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
case _:
|
||||
raise NotAMatch(cls, f"unrecognized cross_attention_dim {cross_attention_dim}")
|
||||
|
||||
# Handle pipelines with a transformer (i.e. SD3).
|
||||
transformer_config_path = mod.path / "transformer" / "config.json"
|
||||
if transformer_config_path.exists():
|
||||
class_name = _get_class_name_from_config(cls, transformer_config_path)
|
||||
match class_name:
|
||||
case "SD3Transformer2DModel":
|
||||
return BaseModelType.StableDiffusion3
|
||||
case "CogView4Transformer2DModel":
|
||||
return BaseModelType.CogView4
|
||||
case _:
|
||||
raise NotAMatch(cls, f"unrecognized transformer class name {class_name}")
|
||||
|
||||
raise NotAMatch(cls, "unable to determine base type")
|
||||
|
||||
@classmethod
|
||||
def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> SchedulerPredictionType:
|
||||
if base not in {
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
}:
|
||||
raise ValueError(f"Attempted to get scheduler prediction type for non-UNet model base '{base}'")
|
||||
|
||||
scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json")
|
||||
|
||||
# TODO(psyche): Is epsilon the right default or should we raise if it's not present?
|
||||
prediction_type = scheduler_conf.get("prediction_type", "epsilon")
|
||||
|
||||
match prediction_type:
|
||||
case "v_prediction":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
case "epsilon":
|
||||
return SchedulerPredictionType.Epsilon
|
||||
case _:
|
||||
raise NotAMatch(cls, f"unrecognized scheduler prediction type {prediction_type}")
|
||||
|
||||
@classmethod
|
||||
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> ModelVariantType:
|
||||
if base not in {
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
}:
|
||||
raise ValueError(f"Attempted to get variant for model base '{base}' but it does not have variants")
|
||||
|
||||
unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json")
|
||||
in_channels = unet_config.get("in_channels")
|
||||
|
||||
match in_channels:
|
||||
case 4:
|
||||
return ModelVariantType.Normal
|
||||
case 5:
|
||||
if base is not BaseModelType.StableDiffusion2:
|
||||
raise NotAMatch(cls, "in_channels=5 is only valid for Stable Diffusion 2 models")
|
||||
return ModelVariantType.Depth
|
||||
case 9:
|
||||
return ModelVariantType.Inpaint
|
||||
case _:
|
||||
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels}")
|
||||
|
||||
@classmethod
|
||||
def _get_submodels_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> dict[SubModelType, SubmodelDefinition]:
|
||||
if base is not BaseModelType.StableDiffusion3:
|
||||
raise ValueError(f"Attempted to get submodels for non-SD3 model base '{base}'")
|
||||
|
||||
# Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json
|
||||
config = _get_config_or_raise(cls, mod.path / "model_index.json")
|
||||
|
||||
submodels: dict[SubModelType, SubmodelDefinition] = {}
|
||||
|
||||
for key, value in config.items():
|
||||
# Anything that starts with an underscore is top-level metadata, not a submodel
|
||||
if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
|
||||
continue
|
||||
# The key is something like "transformer" and is a submodel - it will be in a dir of the same name.
|
||||
# The value value is something like ["diffusers", "SD3Transformer2DModel"]
|
||||
_library_name, class_name = value
|
||||
|
||||
match class_name:
|
||||
case "CLIPTextModelWithProjection":
|
||||
model_type = ModelType.CLIPEmbed
|
||||
path_or_prefix = (mod.path / key).resolve().as_posix()
|
||||
|
||||
# We need to read the config to determine the variant of the CLIP model.
|
||||
clip_embed_config = _get_config_or_raise(cls, mod.path / key / "config.json")
|
||||
variant = _get_clip_variant_type_from_config(clip_embed_config)
|
||||
submodels[SubModelType(key)] = SubmodelDefinition(
|
||||
path_or_prefix=path_or_prefix,
|
||||
model_type=model_type,
|
||||
variant=variant,
|
||||
)
|
||||
case "SD3Transformer2DModel":
|
||||
model_type = ModelType.Main
|
||||
path_or_prefix = (mod.path / key).resolve().as_posix()
|
||||
variant = None
|
||||
submodels[SubModelType(key)] = SubmodelDefinition(
|
||||
path_or_prefix=path_or_prefix,
|
||||
model_type=model_type,
|
||||
variant=variant,
|
||||
)
|
||||
case _:
|
||||
pass
|
||||
|
||||
return submodels
|
||||
|
||||
|
||||
class IPAdapterConfigBase(ABC, BaseModel):
|
||||
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
|
||||
@@ -1231,6 +1381,20 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
|
||||
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
|
||||
|
||||
|
||||
def _get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantType | None:
|
||||
try:
|
||||
hidden_size = config.get("hidden_size")
|
||||
match hidden_size:
|
||||
case 1280:
|
||||
return ClipVariantType.G
|
||||
case 768:
|
||||
return ClipVariantType.L
|
||||
case _:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
||||
"""Model config for Clip Embeddings."""
|
||||
|
||||
@@ -1238,20 +1402,6 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
|
||||
type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
|
||||
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
|
||||
|
||||
@classmethod
|
||||
def _get_clip_variant_type(cls, config: dict[str, Any]) -> ClipVariantType | None:
|
||||
try:
|
||||
hidden_size = config.get("hidden_size")
|
||||
match hidden_size:
|
||||
case 1280:
|
||||
return ClipVariantType.G
|
||||
case 768:
|
||||
return ClipVariantType.L
|
||||
case _:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
"""Model config for CLIP-G Embeddings."""
|
||||
@@ -1269,7 +1419,13 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
_validate_class_name(
|
||||
cls, mod.path / "config.json", {"CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection"}
|
||||
cls,
|
||||
mod.path / "config.json",
|
||||
{
|
||||
"CLIPModel",
|
||||
"CLIPTextModel",
|
||||
"CLIPTextModelWithProjection",
|
||||
},
|
||||
)
|
||||
|
||||
cls._validate_clip_g_variant(mod)
|
||||
@@ -1279,7 +1435,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
@classmethod
|
||||
def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None:
|
||||
config = _get_config_or_raise(cls, mod.path / "config.json")
|
||||
clip_variant = cls._get_clip_variant_type(config)
|
||||
clip_variant = _get_clip_variant_type_from_config(config)
|
||||
|
||||
if clip_variant is not ClipVariantType.G:
|
||||
raise NotAMatch(cls, "model does not match CLIP-G heuristics")
|
||||
@@ -1301,7 +1457,13 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
_validate_class_name(
|
||||
cls, mod.path / "config.json", {"CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection"}
|
||||
cls,
|
||||
mod.path / "config.json",
|
||||
{
|
||||
"CLIPModel",
|
||||
"CLIPTextModel",
|
||||
"CLIPTextModelWithProjection",
|
||||
},
|
||||
)
|
||||
|
||||
cls._validate_clip_l_variant(mod)
|
||||
@@ -1311,7 +1473,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
|
||||
@classmethod
|
||||
def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None:
|
||||
config = _get_config_or_raise(cls, mod.path / "config.json")
|
||||
clip_variant = cls._get_clip_variant_type(config)
|
||||
clip_variant = _get_clip_variant_type_from_config(config)
|
||||
|
||||
if clip_variant is not ClipVariantType.L:
|
||||
raise NotAMatch(cls, "model does not match CLIP-G heuristics")
|
||||
@@ -1330,7 +1492,13 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
_validate_class_name(cls, mod.path / "config.json", {"CLIPVisionModelWithProjection"})
|
||||
_validate_class_name(
|
||||
cls,
|
||||
mod.path / "config.json",
|
||||
{
|
||||
"CLIPVisionModelWithProjection",
|
||||
},
|
||||
)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
@@ -1354,7 +1522,13 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
|
||||
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
_validate_class_name(cls, mod.path / "config.json", {"T2IAdapter"})
|
||||
_validate_class_name(
|
||||
cls,
|
||||
mod.path / "config.json",
|
||||
{
|
||||
"T2IAdapter",
|
||||
},
|
||||
)
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
|
||||
@@ -1421,7 +1595,13 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
_validate_class_name(cls, mod.path / "config.json", {"SiglipModel"})
|
||||
_validate_class_name(
|
||||
cls,
|
||||
mod.path / "config.json",
|
||||
{
|
||||
"SiglipModel",
|
||||
},
|
||||
)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
@@ -1458,7 +1638,13 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
|
||||
_validate_override_fields(cls, fields)
|
||||
|
||||
_validate_class_name(cls, mod.path / "config.json", {"LlavaOnevisionForConditionalGeneration"})
|
||||
_validate_class_name(
|
||||
cls,
|
||||
mod.path / "config.json",
|
||||
{
|
||||
"LlavaOnevisionForConditionalGeneration",
|
||||
},
|
||||
)
|
||||
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user