mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(mm): wip port of main models to new api
This commit is contained in:
@@ -1126,12 +1126,29 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
||||
)
|
||||
|
||||
base = fields.get("base") or cls._get_base_or_raise(mod)
|
||||
if base in {
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
}:
|
||||
variant = fields.get("variant") or cls._get_variant_or_raise(mod, base)
|
||||
prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod, base)
|
||||
upcast_attention = fields.get("upcast_attention") or cls._get_upcast_attention_or_raise(base, prediction_type)
|
||||
else:
|
||||
variant= None
|
||||
prediction_type = None
|
||||
upcast_attention = False
|
||||
|
||||
return cls(**fields, base=base)
|
||||
if base is BaseModelType.StableDiffusion3:
|
||||
submodels = fields.get("submodels") or cls._get_submodels_or_raise(mod, base)
|
||||
else:
|
||||
submodels = None
|
||||
|
||||
return cls(**fields, base=base, variant=variant, prediction_type=prediction_type, upcast_attention=upcast_attention, submodels=submodels,)
|
||||
|
||||
@classmethod
|
||||
def _get_base_or_raise(cls, mod: ModelOnDisk) -> MainDiffusers_SupportedBases:
|
||||
# Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
|
||||
# Handle pipelines with a UNet (i.e SD 1.x, SD2.x, SDXL).
|
||||
unet_config_path = mod.path / "unet" / "config.json"
|
||||
if unet_config_path.exists():
|
||||
with open(unet_config_path) as file:
|
||||
@@ -1172,7 +1189,7 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
||||
BaseModelType.StableDiffusion2,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
}:
|
||||
raise ValueError(f"Attempted to get scheduler prediction type for non-UNet model base '{base}'")
|
||||
raise ValueError(f"Attempted to get scheduler prediction_type for non-UNet model base '{base}'")
|
||||
|
||||
scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json")
|
||||
|
||||
@@ -1185,7 +1202,7 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
||||
case "epsilon":
|
||||
return SchedulerPredictionType.Epsilon
|
||||
case _:
|
||||
raise NotAMatch(cls, f"unrecognized scheduler prediction type {prediction_type}")
|
||||
raise NotAMatch(cls, f"unrecognized scheduler prediction_type {prediction_type}")
|
||||
|
||||
@classmethod
|
||||
def _get_variant_or_raise(cls, mod: ModelOnDisk, base: MainDiffusers_SupportedBases) -> ModelVariantType:
|
||||
@@ -1266,6 +1283,21 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin,
|
||||
return submodels
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_upcast_attention_or_raise(cls, base: MainDiffusers_SupportedBases, prediction_type: SchedulerPredictionType) -> bool:
|
||||
if base not in {
|
||||
BaseModelType.StableDiffusion1,
|
||||
BaseModelType.StableDiffusion2,
|
||||
BaseModelType.StableDiffusionXL,
|
||||
}:
|
||||
raise ValueError(f"Attempted to get upcast_attention flag for non-UNet model base '{base}'")
|
||||
|
||||
if base is BaseModelType.StableDiffusion2 and prediction_type is SchedulerPredictionType.VPrediction:
|
||||
# SD2 v-prediction models need upcast_attention to be True
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
class IPAdapterConfigBase(ABC, BaseModel):
|
||||
type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user