docs(mm): add todos

This commit is contained in:
psychedelicious
2025-09-26 18:35:00 +10:00
parent 0192caa90f
commit 54e3c3e209

View File

@@ -1133,9 +1133,11 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
}:
variant = fields.get("variant") or cls._get_variant_or_raise(mod, base)
prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base)
upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise(base, prediction_type)
upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise(
base, prediction_type
)
else:
variant= None
variant = None
prediction_type = None
upcast_attention = False
@@ -1144,7 +1146,16 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
else:
submodels = None
return cls(**fields, base=base, variant=variant, prediction_type=prediction_type, upcast_attention=upcast_attention, submodels=submodels,)
return cls(
**fields,
base=base,
# TODO(psyche): figure out variant/prediction_type/upcast_attention
variant=variant,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
# TODO(psyche): This is only for SD3 models - split up the config classes
submodels=submodels,
)
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases:
@@ -1282,9 +1293,10 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
return submodels
@classmethod
def _get_upcast_attention_or_raise(cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType) -> bool:
def _get_upcast_attention_or_raise(
cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType
) -> bool:
if base not in {
BaseModelType.StableDiffusion1,
BaseModelType.StableDiffusion2,
@@ -1298,6 +1310,7 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
return False
class IPAdapterConfigBase(ABC, BaseModel):
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)