Remove hash algo

This commit is contained in:
Billy
2025-03-12 18:39:29 +11:00
parent be53b89203
commit f45400a275
2 changed files with 11 additions and 7 deletions

View File

@@ -649,7 +649,7 @@ class ModelInstallService(ModelInstallServiceBase):
fields = config.model_dump()
try:
return ModelConfigBase.classify(model_path, **fields)
return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
except InvalidModelConfigException:
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore

View File

@@ -24,6 +24,7 @@ import logging
import time
from abc import ABC, abstractmethod
from enum import Enum
from functools import cached_property
from inspect import isabstract
from pathlib import Path
from typing import ClassVar, Literal, Optional, TypeAlias, Union
@@ -203,13 +204,18 @@ class ControlAdapterDefaultSettings(BaseModel):
class ModelOnDisk:
"""A utility class representing a model stored on disk."""
def __init__(self, path: Path):
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.path = path
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
self.name = path.name
self.hash_algo = hash_algo
@cached_property
def hash(self):
return ModelHash(algorithm=self.hash_algo).hash(self.path)
def lazy_load_state_dict(self) -> dict[str, torch.Tensor]:
raise NotImplementedError()
@@ -282,7 +288,7 @@ class ModelConfigBase(ABC, BaseModel):
return concrete
@staticmethod
def classify(path: Path, **overrides):
def classify(model_path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
"""
Returns the best matching ModelConfig instance from a model's file/folder path.
Raises InvalidModelConfigException if no valid configuration is found.
@@ -290,7 +296,7 @@ class ModelConfigBase(ABC, BaseModel):
"""
candidates = ModelConfigBase._USING_CLASSIFY_API
sorted_by_match_speed = sorted(candidates, key=lambda cls: cls._MATCH_SPEED)
mod = ModelOnDisk(path)
mod = ModelOnDisk(model_path, hash_algo)
for config_cls in sorted_by_match_speed:
try:
@@ -335,9 +341,7 @@ class ModelConfigBase(ABC, BaseModel):
fields["source"] = fields.get("source") or fields["path"]
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
fields["name"] = mod.name
default_hash_algo: HASHING_ALGORITHMS = "blake3_single"
fields["hash_algo"] = hash_algo = fields.get("hash_algo", default_hash_algo)
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(mod.path)
fields["hash"] = fields.get("hash") or mod.hash
fields.update(overrides)
return cls(**fields)