mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
docs(mm): add todos
This commit is contained in:
@@ -1133,9 +1133,11 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
||||
}:
|
||||
variant = fields.get("variant") or cls._get_variant_or_raise(mod, base)
|
||||
prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base)
|
||||
upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise(base, prediction_type)
|
||||
upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise(
|
||||
base, prediction_type
|
||||
)
|
||||
else:
|
||||
variant= None
|
||||
variant = None
|
||||
prediction_type = None
|
||||
upcast_attention = False
|
||||
|
||||
@@ -1144,7 +1146,16 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
||||
else:
|
||||
submodels = None
|
||||
|
||||
return cls(**fields, base=base, variant=variant, prediction_type=prediction_type, upcast_attention=upcast_attention, submodels=submodels,)
|
||||
return cls(
|
||||
**fields,
|
||||
base=base,
|
||||
# TODO(psyche): figure out variant/prediction_type/upcast_attention
|
||||
variant=variant,
|
||||
prediction_type=prediction_type,
|
||||
upcast_attention=upcast_attention,
|
||||
# TODO(psyche): This is only for SD3 models - split up the config classes
|
||||
submodels=submodels,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases:
|
||||
@@ -1282,9 +1293,10 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
||||
|
||||
return submodels
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_upcast_attention_or_raise(cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType) -> bool:
|
||||
def _get_upcast_attention_or_raise(
|
||||
cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType
|
||||
) -> bool:
|
||||
if base not in {
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
@@ -1298,6 +1310,7 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class IPAdapterConfigBase(ABC, BaseModel):
|
||||
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user