# Mirror of https://github.com/invoke-ai/InvokeAI.git
# Synced 2026-02-15 03:25:20 -05:00
# 61 lines, 2.1 KiB, Python
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from invokeai.backend.model_manager.config import (
|
|
AnyModelConfig,
|
|
CheckpointConfigBase,
|
|
DiffusersConfigBase,
|
|
)
|
|
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
|
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
|
from invokeai.backend.model_manager.taxonomy import (
|
|
AnyModel,
|
|
BaseModelType,
|
|
ModelFormat,
|
|
ModelType,
|
|
SubModelType,
|
|
)
|
|
|
|
|
|
@ModelLoaderRegistry.register(base=BaseModelType.CogView4, type=ModelType.Main, format=ModelFormat.Diffusers)
class CogView4DiffusersModel(GenericDiffusersLoader):
    """Loader for CogView4 main (pipeline) models stored in diffusers format.

    Registered with the model-loader registry for (CogView4, Main, Diffusers);
    loading is delegated to the HF class resolved by `get_hf_load_class`.
    """

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Load one submodel of a CogView4 diffusers pipeline.

        Raises:
            NotImplementedError: if a checkpoint-format config is supplied.
            Exception: if no submodel type is given (main pipelines require one).
        """
        # Single-file checkpoint configs are not supported for CogView4.
        if isinstance(config, CheckpointConfigBase):
            raise NotImplementedError("CheckpointConfigBase is not implemented for CogView4 models.")

        if submodel_type is None:
            raise Exception("A submodel type must be provided when loading main pipelines.")

        base_path = Path(config.path)
        hf_class = self.get_hf_load_class(base_path, submodel_type)

        # Only diffusers-style configs carry a repo variant (e.g. "fp16").
        repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
        variant_name = repo_variant.value if repo_variant else None

        submodel_path = base_path / submodel_type.value

        # We force bfloat16 for CogView4 models. It produces black images with float16. I haven't tracked down
        # specifically which model(s) is/are responsible.
        dtype = torch.bfloat16

        try:
            loaded: AnyModel = hf_class.from_pretrained(submodel_path, torch_dtype=dtype, variant=variant_name)
        except OSError as err:
            if variant_name and "no file named" in str(
                err
            ):  # try without the variant, just in case user's preferences changed
                loaded = hf_class.from_pretrained(submodel_path, torch_dtype=dtype)
            else:
                raise err

        return loaded
|