Small improvements (#7842)

## Summary

- Extend `ModelOnDisk` with caching, type hints, default args
- Fail early if there is an error classifying a config

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
This commit is contained in:
jazzhaiku
2025-03-28 12:21:41 +11:00
committed by GitHub
2 changed files with 42 additions and 19 deletions

View File

@@ -67,6 +67,11 @@ class InvalidModelConfigException(Exception):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
class FSLayout(Enum):
FILE = "file"
DIRECTORY = "directory"
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
@@ -102,29 +107,31 @@ class ModelOnDisk:
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.path = path
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
# TODO: Revisit checkpoint vs diffusers terminology
self.layout = FSLayout.DIRECTORY if path.is_dir() else FSLayout.FILE
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
self.name = path.name
self.hash_algo = hash_algo
self._state_dict_cache = {}
def hash(self):
def hash(self) -> str:
return ModelHash(algorithm=self.hash_algo).hash(self.path)
def size(self):
if self.format_type == ModelFormat.Checkpoint:
def size(self) -> int:
if self.layout == FSLayout.FILE:
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))
def component_paths(self):
if self.format_type == ModelFormat.Checkpoint:
def component_paths(self) -> set[Path]:
if self.layout == FSLayout.FILE:
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}
def repo_variant(self):
if self.format_type == ModelFormat.Checkpoint:
def repo_variant(self) -> Optional[ModelRepoVariant]:
if self.layout == FSLayout.FILE:
return None
weight_files = list(self.path.glob("**/*.safetensors"))
@@ -140,14 +147,30 @@ class ModelOnDisk:
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default
@staticmethod
def load_state_dict(path: Path):
def load_state_dict(self, path: Optional[Path] = None) -> Dict[str | int, Any]:
if path in self._state_dict_cache:
return self._state_dict_cache[path]
if not path:
components = list(self.component_paths())
match components:
case []:
raise ValueError("No weight files found for this model")
case [p]:
path = p
case ps if len(ps) >= 2:
raise ValueError(
f"Multiple weight files found for this model: {ps}. "
f"Please specify the intended file using the 'path' argument"
)
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
assert isinstance(checkpoint, dict)
elif path.suffix.endswith(".gguf"):
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
elif path.suffix.endswith(".safetensors"):
@@ -156,6 +179,7 @@ class ModelOnDisk:
raise ValueError(f"Unrecognized model extension: {path.suffix}")
state_dict = checkpoint.get("state_dict", checkpoint)
self._state_dict_cache[path] = state_dict
return state_dict
@@ -238,11 +262,13 @@ class ModelConfigBase(ABC, BaseModel):
for config_cls in sorted_by_match_speed:
try:
return config_cls.from_model_on_disk(mod, **overrides)
except InvalidModelConfigException:
logger.debug(f"ModelConfig '{config_cls.__name__}' failed to parse '{mod.path}', trying next config")
if not config_cls.matches(mod):
continue
except Exception as e:
logger.error(f"Unexpected exception while parsing '{config_cls.__name__}': {e}, trying next config")
logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
continue
else:
return config_cls.from_model_on_disk(mod, **overrides)
raise InvalidModelConfigException("No valid config found")
@@ -285,9 +311,6 @@ class ModelConfigBase(ABC, BaseModel):
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
"""Creates an instance of this config or raises InvalidModelConfigException."""
if not cls.matches(mod):
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
fields = cls.parse(mod)
cls.cast_overrides(overrides)
fields.update(overrides)
@@ -563,7 +586,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.format_type == ModelFormat.Checkpoint:
if mod.layout == FSLayout.FILE:
return False
config_path = mod.path / "config.json"

View File

@@ -71,7 +71,7 @@ def create_stripped_model(original_model_path: Path, stripped_model_path: Path)
print(f"Created clone of {original.name} at {stripped.path}")
for component_path in stripped.component_paths():
original_state_dict = ModelOnDisk.load_state_dict(component_path)
original_state_dict = stripped.load_state_dict(component_path)
stripped_state_dict = strip(original_state_dict) # type: ignore
with open(component_path, "w") as f:
json.dump(stripped_state_dict, f, indent=4)