diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 822856e520..b109fa6297 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -406,16 +406,94 @@ class LoRAConfigBase(ABC, BaseModel): class T5EncoderConfigBase(ABC, BaseModel): """Base class for diffusers-style models.""" + base: Literal[BaseModelType.Any] = BaseModelType.Any type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder + @classmethod + def get_config(cls, mod: ModelOnDisk) -> dict[str, Any]: + path = mod.path / "text_encoder_2" / "config.json" + with open(path, "r") as file: + return json.load(file) -class T5EncoderConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase): + @classmethod + def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: + return {} + + +class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase): format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder + @classmethod + def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: + is_t5_type_override = overrides.get("type") is ModelType.T5Encoder + is_t5_format_override = overrides.get("format") is ModelFormat.T5Encoder -class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase): + if is_t5_type_override and is_t5_format_override: + return MatchCertainty.OVERRIDE + + if mod.path.is_file(): + return MatchCertainty.NEVER + + model_dir = mod.path / "text_encoder_2" + + if not model_dir.exists(): + return MatchCertainty.NEVER + + try: + config = cls.get_config(mod) + + is_t5_encoder_model = get_class_name_from_config(config) == "T5EncoderModel" + is_t5_format = (model_dir / "model.safetensors.index.json").exists() + + if is_t5_encoder_model and is_t5_format: + return MatchCertainty.EXACT + except Exception: + pass + + return MatchCertainty.NEVER + + +class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase): format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b + @classmethod + def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty: + is_t5_type_override = overrides.get("type") is ModelType.T5Encoder + is_bnb_format_override = overrides.get("format") is ModelFormat.BnbQuantizedLlmInt8b + + if is_t5_type_override and is_bnb_format_override: + return MatchCertainty.OVERRIDE + + if mod.path.is_file(): + return MatchCertainty.NEVER + + model_dir = mod.path / "text_encoder_2" + + if not model_dir.exists(): + return MatchCertainty.NEVER + + try: + config = cls.get_config(mod) + + is_t5_encoder_model = get_class_name_from_config(config) == "T5EncoderModel" + + # Heuristic: look for the quantization in the name + files = model_dir.glob("*.safetensors") + filename_looks_like_bnb = any(x for x in files if "llm_int8" in x.as_posix()) + + if is_t5_encoder_model and filename_looks_like_bnb: + return MatchCertainty.EXACT + + # Heuristic: Look for the presence of "SCB" in state dict keys (typically a suffix) + has_scb_key = mod.has_keys_ending_with("SCB") + + if is_t5_encoder_model and has_scb_key: + return MatchCertainty.EXACT + except Exception: + pass + + return MatchCertainty.NEVER + class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): format: Literal[ModelFormat.OMI] = ModelFormat.OMI diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py index 5955c8af2c..85a39fd25e 100644 --- a/invokeai/backend/model_manager/legacy_probe.py +++ b/invokeai/backend/model_manager/legacy_probe.py @@ -879,30 +879,6 @@ class PipelineFolderProbe(FolderProbeBase): return ModelVariantType.Normal -class T5EncoderFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - def get_format(self) -> ModelFormat: - path = self.model_path / "text_encoder_2" - if (path / "model.safetensors.index.json").exists(): - return ModelFormat.T5Encoder - files = list(path.glob("*.safetensors")) - if len(files) == 0: - raise InvalidModelConfigException(f"{self.model_path.as_posix()}: no .safetensors files found") - - # shortcut: look for the quantization in the name - if any(x for x in files if "llm_int8" in x.as_posix()): - return ModelFormat.BnbQuantizedLlmInt8b - - # more reliable path: probe contents for a 'SCB' key - ckpt = read_checkpoint_meta(files[0], scan=True) - if any("SCB" in x for x in ckpt.keys()): - return ModelFormat.BnbQuantizedLlmInt8b - - raise InvalidModelConfigException(f"{self.model_path.as_posix()}: unknown model format") - - class ONNXFolderProbe(PipelineFolderProbe): def get_base_type(self) -> BaseModelType: # Due to the way the installer is set up, the configuration file for safetensors @@ -1036,7 +1012,6 @@ class T2IAdapterFolderProbe(FolderProbeBase): ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe) ModelProbe.register_probe("diffusers", ModelType.ControlLoRa, LoRAFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe) ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) diff --git a/invokeai/backend/model_manager/model_on_disk.py b/invokeai/backend/model_manager/model_on_disk.py index 9de78e53b3..1f7625e09c 100644 --- a/invokeai/backend/model_manager/model_on_disk.py +++ b/invokeai/backend/model_manager/model_on_disk.py @@ -129,18 +129,21 @@ class ModelOnDisk: ) return path - def has_keys_exact(self, keys: set[str], path: Optional[Path] = None) -> bool: + def has_keys_exact(self, keys: str | set[str], path: Optional[Path] = None) -> bool: + _keys = {keys} if isinstance(keys, str) else keys state_dict = self.load_state_dict(path) - return keys.issubset({key for key in state_dict.keys() if isinstance(key, str)}) + return _keys.issubset({key for key in state_dict.keys() if isinstance(key, str)}) - def has_keys_starting_with(self, prefixes: set[str], path: Optional[Path] = None) -> bool: + def has_keys_starting_with(self, prefixes: str | set[str], path: Optional[Path] = None) -> bool: + _prefixes = {prefixes} if isinstance(prefixes, str) else prefixes state_dict = self.load_state_dict(path) return any( - any(key.startswith(prefix) for prefix in prefixes) for key in state_dict.keys() if isinstance(key, str) + any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str) ) - def has_keys_ending_with(self, prefixes: set[str], path: Optional[Path] = None) -> bool: + def has_keys_ending_with(self, suffixes: str | set[str], path: Optional[Path] = None) -> bool: + _suffixes = {suffixes} if isinstance(suffixes, str) else suffixes state_dict = self.load_state_dict(path) return any( - any(key.endswith(suffix) for suffix in prefixes) for key in state_dict.keys() if isinstance(key, str) + any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str) )