feat(mm): port t5 to new API

This commit is contained in:
psychedelicious
2025-09-23 19:14:02 +10:00
parent 4b1450a4ff
commit 250163e6b7
3 changed files with 89 additions and 33 deletions

View File

@@ -406,16 +406,94 @@ class LoRAConfigBase(ABC, BaseModel):
class T5EncoderConfigBase(ABC, BaseModel):
"""Base class for diffusers-style models."""
base: Literal[BaseModelType.Any] = BaseModelType.Any
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
@classmethod
def get_config(cls, mod: ModelOnDisk) -> dict[str, Any]:
path = mod.path / "text_encoder_2" / "config.json"
with open(path, "r") as file:
return json.load(file)
class T5EncoderConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase):
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {}
class T5EncoderConfig(T5EncoderConfigBase, ModelConfigBase):
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
is_t5_type_override = overrides.get("type") is ModelType.T5Encoder
is_t5_format_override = overrides.get("format") is ModelFormat.T5Encoder
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase):
if is_t5_type_override and is_t5_format_override:
return MatchCertainty.OVERRIDE
if mod.path.is_file():
return MatchCertainty.NEVER
model_dir = mod.path / "text_encoder_2"
if not model_dir.exists():
return MatchCertainty.NEVER
try:
config = cls.get_config(mod)
is_t5_encoder_model = get_class_name_from_config(config) == "T5EncoderModel"
is_t5_format = (model_dir / "model.safetensors.index.json").exists()
if is_t5_encoder_model and is_t5_format:
return MatchCertainty.EXACT
except Exception:
pass
return MatchCertainty.NEVER
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, ModelConfigBase):
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
@classmethod
def matches(cls, mod: ModelOnDisk, **overrides) -> MatchCertainty:
is_t5_type_override = overrides.get("type") is ModelType.T5Encoder
is_bnb_format_override = overrides.get("format") is ModelFormat.BnbQuantizedLlmInt8b
if is_t5_type_override and is_bnb_format_override:
return MatchCertainty.OVERRIDE
if mod.path.is_file():
return MatchCertainty.NEVER
model_dir = mod.path / "text_encoder_2"
if not model_dir.exists():
return MatchCertainty.NEVER
try:
config = cls.get_config(mod)
is_t5_encoder_model = get_class_name_from_config(config) == "T5EncoderModel"
# Heuristic: look for the quantization in the name
files = model_dir.glob("*.safetensors")
filename_looks_like_bnb = any(x for x in files if "llm_int8" in x.as_posix())
if is_t5_encoder_model and filename_looks_like_bnb:
return MatchCertainty.EXACT
# Heuristic: Look for the presence of "SCB" in state dict keys (typically a suffix)
has_scb_key = mod.has_keys_ending_with("SCB")
if is_t5_encoder_model and has_scb_key:
return MatchCertainty.EXACT
except Exception:
pass
return MatchCertainty.NEVER
class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
format: Literal[ModelFormat.OMI] = ModelFormat.OMI

View File

@@ -879,30 +879,6 @@ class PipelineFolderProbe(FolderProbeBase):
return ModelVariantType.Normal
class T5EncoderFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
def get_format(self) -> ModelFormat:
path = self.model_path / "text_encoder_2"
if (path / "model.safetensors.index.json").exists():
return ModelFormat.T5Encoder
files = list(path.glob("*.safetensors"))
if len(files) == 0:
raise InvalidModelConfigException(f"{self.model_path.as_posix()}: no .safetensors files found")
# shortcut: look for the quantization in the name
if any(x for x in files if "llm_int8" in x.as_posix()):
return ModelFormat.BnbQuantizedLlmInt8b
# more reliable path: probe contents for a 'SCB' key
ckpt = read_checkpoint_meta(files[0], scan=True)
if any("SCB" in x for x in ckpt.keys()):
return ModelFormat.BnbQuantizedLlmInt8b
raise InvalidModelConfigException(f"{self.model_path.as_posix()}: unknown model format")
class ONNXFolderProbe(PipelineFolderProbe):
def get_base_type(self) -> BaseModelType:
# Due to the way the installer is set up, the configuration file for safetensors
@@ -1036,7 +1012,6 @@ class T2IAdapterFolderProbe(FolderProbeBase):
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlLoRa, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)

View File

@@ -129,18 +129,21 @@ class ModelOnDisk:
)
return path
def has_keys_exact(self, keys: set[str], path: Optional[Path] = None) -> bool:
def has_keys_exact(self, keys: str | set[str], path: Optional[Path] = None) -> bool:
_keys = {keys} if isinstance(keys, str) else keys
state_dict = self.load_state_dict(path)
return keys.issubset({key for key in state_dict.keys() if isinstance(key, str)})
return _keys.issubset({key for key in state_dict.keys() if isinstance(key, str)})
def has_keys_starting_with(self, prefixes: set[str], path: Optional[Path] = None) -> bool:
def has_keys_starting_with(self, prefixes: str | set[str], path: Optional[Path] = None) -> bool:
_prefixes = {prefixes} if isinstance(prefixes, str) else prefixes
state_dict = self.load_state_dict(path)
return any(
any(key.startswith(prefix) for prefix in prefixes) for key in state_dict.keys() if isinstance(key, str)
any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str)
)
def has_keys_ending_with(self, prefixes: set[str], path: Optional[Path] = None) -> bool:
def has_keys_ending_with(self, suffixes: str | set[str], path: Optional[Path] = None) -> bool:
_suffixes = {suffixes} if isinstance(suffixes, str) else suffixes
state_dict = self.load_state_dict(path)
return any(
any(key.endswith(suffix) for suffix in prefixes) for key in state_dict.keys() if isinstance(key, str)
any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str)
)