mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
merge with main
This commit is contained in:
@@ -18,8 +18,7 @@ from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||
class VAELoader(GenericDiffusersLoader):
|
||||
"""Class to load VAE models."""
|
||||
|
||||
|
||||
@@ -451,8 +451,16 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
|
||||
class VaeCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
# I can't find any standalone 2.X VAEs to test with!
|
||||
return BaseModelType.StableDiffusion1
|
||||
# VAEs of all base types have the same structure, so we wimp out and
|
||||
# guess using the name.
|
||||
for regexp, basetype in [
|
||||
(r"xl", BaseModelType.StableDiffusionXL),
|
||||
(r"sd2", BaseModelType.StableDiffusion2),
|
||||
(r"vae", BaseModelType.StableDiffusion1),
|
||||
]:
|
||||
if re.search(regexp, self.model_path.name, re.IGNORECASE):
|
||||
return basetype
|
||||
raise InvalidModelConfigException("Cannot determine base type")
|
||||
|
||||
|
||||
class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
|
||||
Reference in New Issue
Block a user