feat(mm): wip port of main models to new api

This commit is contained in:
psychedelicious
2025-09-25 23:02:44 +10:00
parent f9686b38fa
commit 6f5720904a

View File

@@ -1094,7 +1094,6 @@ MainDiffusers_SupportedBases: TypeAlias = Literal[
BaseModelType.StableDiffusion3,
BaseModelType.StableDiffusionXL,
BaseModelType.StableDiffusionXLRefiner,
BaseModelType.Flux,
BaseModelType.CogView4,
]
@@ -1104,6 +1103,157 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
base: MainDiffusers_SupportedBases = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self:
_validate_is_dir(cls, mod)
_validate_override_fields(cls, fields)
_validate_class_name(
cls,
mod.path / "config.json",
{
"StableDiffusionPipeline",
"StableDiffusionInpaintPipeline",
"StableDiffusionXLPipeline",
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusion3Pipeline",
"LatentConsistencyModelPipeline",
"SD3Transformer2DModel",
"CogView4Pipeline",
},
)
base = fields.get("base") or cls._get_base_or_raise(mod)
return cls(**fields, base=base)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases:
# Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
unet_config_path = mod.path / "unet" / "config.json"
if unet_config_path.exists():
with open(unet_config_path) as file:
unet_conf = json.load(file)
cross_attention_dim = unet_conf.get("cross_attention_dim")
match cross_attention_dim:
case 768:
return BaseModelType.StableDiffusion1
case 1024:
return BaseModelType.StableDiffusion2
case 1280:
return BaseModelType.StableDiffusionXLRefiner
case 2048:
return BaseModelType.StableDiffusionXL
case _:
raise NotAMatch(cls, f"unrecognized cross_attention_dim {cross_attention_dim}")
# Handle pipelines with a transformer (i.e. SD3).
transformer_config_path = mod.path / "transformer" / "config.json"
if transformer_config_path.exists():
class_name = _get_class_name_from_config(cls, transformer_config_path)
match class_name:
case "SD3Transformer2DModel":
return BaseModelType.StableDiffusion3
case "CogView4Transformer2DModel":
return BaseModelType.CogView4
case _:
raise NotAMatch(cls, f"unrecognized transformer class name {class_name}")
raise NotAMatch(cls, "unable to determine base type")
@classmethod
def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> SchedulerPredictionType:
if base not in {
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
}:
raise ValueError(f"Attempted to get scheduler prediction type for non-UNet model base '{base}'")
scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json")
# TODO(psyche): Is epsilon the right default or should we raise if it's not present?
prediction_type = scheduler_conf.get("prediction_type", "epsilon")
match prediction_type:
case "v_prediction":
return SchedulerPredictionType.VPrediction
case "epsilon":
return SchedulerPredictionType.Epsilon
case _:
raise NotAMatch(cls, f"unrecognized scheduler prediction type {prediction_type}")
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> ModelVariantType:
if base not in {
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
BaseModelType.StableDiffusionXL,
}:
raise ValueError(f"Attempted to get variant for model base '{base}' but it does not have variants")
unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json")
in_channels = unet_config.get("in_channels")
match in_channels:
case 4:
return ModelVariantType.Normal
case 5:
if base is not BaseModelType.StableDiffusion2:
raise NotAMatch(cls, "in_channels=5 is only valid for Stable Diffusion 2 models")
return ModelVariantType.Depth
case 9:
return ModelVariantType.Inpaint
case _:
raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels}")
@classmethod
def _get_submodels_or_raise(cls, mod: ModelOnDisk, base: BaseModelType) -> dict[SubModelType, SubmodelDefinition]:
if base is not BaseModelType.StableDiffusion3:
raise ValueError(f"Attempted to get submodels for non-SD3 model base '{base}'")
# Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json
config = _get_config_or_raise(cls, mod.path / "model_index.json")
submodels: dict[SubModelType, SubmodelDefinition] = {}
for key, value in config.items():
# Anything that starts with an underscore is top-level metadata, not a submodel
if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
continue
# The key is something like "transformer" and is a submodel - it will be in a dir of the same name.
# The value value is something like ["diffusers", "SD3Transformer2DModel"]
_library_name, class_name = value
match class_name:
case "CLIPTextModelWithProjection":
model_type = ModelType.CLIPEmbed
path_or_prefix = (mod.path / key).resolve().as_posix()
# We need to read the config to determine the variant of the CLIP model.
clip_embed_config = _get_config_or_raise(cls, mod.path / key / "config.json")
variant = _get_clip_variant_type_from_config(clip_embed_config)
submodels[SubModelType(key)] = SubmodelDefinition(
path_or_prefix=path_or_prefix,
model_type=model_type,
variant=variant,
)
case "SD3Transformer2DModel":
model_type = ModelType.Main
path_or_prefix = (mod.path / key).resolve().as_posix()
variant = None
submodels[SubModelType(key)] = SubmodelDefinition(
path_or_prefix=path_or_prefix,
model_type=model_type,
variant=variant,
)
case _:
pass
return submodels
class IPAdapterConfigBase(ABC, BaseModel):
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
@@ -1231,6 +1381,20 @@ class IPAdapterCheckpointConfig(IPAdapterConfigBase, ModelConfigBase):
raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}")
def _get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantType | None:
try:
hidden_size = config.get("hidden_size")
match hidden_size:
case 1280:
return ClipVariantType.G
case 768:
return ClipVariantType.L
case _:
return None
except Exception:
return None
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
"""Model config for Clip Embeddings."""
@@ -1238,20 +1402,6 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
@classmethod
def _get_clip_variant_type(cls, config: dict[str, Any]) -> ClipVariantType | None:
try:
hidden_size = config.get("hidden_size")
match hidden_size:
case 1280:
return ClipVariantType.G
case 768:
return ClipVariantType.L
case _:
return None
except Exception:
return None
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
"""Model config for CLIP-G Embeddings."""
@@ -1269,7 +1419,13 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
_validate_override_fields(cls, fields)
_validate_class_name(
cls, mod.path / "config.json", {"CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection"}
cls,
mod.path / "config.json",
{
"CLIPModel",
"CLIPTextModel",
"CLIPTextModelWithProjection",
},
)
cls._validate_clip_g_variant(mod)
@@ -1279,7 +1435,7 @@ class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
@classmethod
def _validate_clip_g_variant(cls, mod: ModelOnDisk) -> None:
config = _get_config_or_raise(cls, mod.path / "config.json")
clip_variant = cls._get_clip_variant_type(config)
clip_variant = _get_clip_variant_type_from_config(config)
if clip_variant is not ClipVariantType.G:
raise NotAMatch(cls, "model does not match CLIP-G heuristics")
@@ -1301,7 +1457,13 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
_validate_override_fields(cls, fields)
_validate_class_name(
cls, mod.path / "config.json", {"CLIPModel", "CLIPTextModel", "CLIPTextModelWithProjection"}
cls,
mod.path / "config.json",
{
"CLIPModel",
"CLIPTextModel",
"CLIPTextModelWithProjection",
},
)
cls._validate_clip_l_variant(mod)
@@ -1311,7 +1473,7 @@ class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, ModelConfigBase):
@classmethod
def _validate_clip_l_variant(cls, mod: ModelOnDisk) -> None:
config = _get_config_or_raise(cls, mod.path / "config.json")
clip_variant = cls._get_clip_variant_type(config)
clip_variant = _get_clip_variant_type_from_config(config)
if clip_variant is not ClipVariantType.L:
raise NotAMatch(cls, "model does not match CLIP-G heuristics")
@@ -1330,7 +1492,13 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase, ModelConfigBase):
_validate_override_fields(cls, fields)
_validate_class_name(cls, mod.path / "config.json", {"CLIPVisionModelWithProjection"})
_validate_class_name(
cls,
mod.path / "config.json",
{
"CLIPVisionModelWithProjection",
},
)
return cls(**fields)
@@ -1354,7 +1522,13 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, ModelConfi
_validate_override_fields(cls, fields)
_validate_class_name(cls, mod.path / "config.json", {"T2IAdapter"})
_validate_class_name(
cls,
mod.path / "config.json",
{
"T2IAdapter",
},
)
base = fields.get("base") or cls._get_base_or_raise(mod)
@@ -1421,7 +1595,13 @@ class SigLIPConfig(DiffusersConfigBase, ModelConfigBase):
_validate_override_fields(cls, fields)
_validate_class_name(cls, mod.path / "config.json", {"SiglipModel"})
_validate_class_name(
cls,
mod.path / "config.json",
{
"SiglipModel",
},
)
return cls(**fields)
@@ -1458,7 +1638,13 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
_validate_override_fields(cls, fields)
_validate_class_name(cls, mod.path / "config.json", {"LlavaOnevisionForConditionalGeneration"})
_validate_class_name(
cls,
mod.path / "config.json",
{
"LlavaOnevisionForConditionalGeneration",
},
)
return cls(**fields)