Deprecate checkpoint as file, diffusers as directory terminology

This commit is contained in:
Billy
2025-03-27 08:10:12 +11:00
parent 60b5aef16a
commit 82dd2d508f

View File

@@ -60,13 +60,17 @@ logger = logging.getLogger(__name__)
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognize this combination of model type and format."""
pass
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
class FSLayout(Enum):
FILE = "file"
DIRECTORY = "directory"
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
@@ -103,7 +107,7 @@ class ModelOnDisk:
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.path = path
# TODO: Revisit checkpoint vs diffusers terminology
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
self.layout = FSLayout.DIRECTORY if path.is_dir() else FSLayout.FILE
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
@@ -115,18 +119,18 @@ class ModelOnDisk:
return ModelHash(algorithm=self.hash_algo).hash(self.path)
def size(self) -> int:
if self.format_type == ModelFormat.Checkpoint:
if self.layout == FSLayout.FILE:
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))
def component_paths(self) -> set[Path]:
if self.format_type == ModelFormat.Checkpoint:
if self.layout == FSLayout.FILE:
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}
def repo_variant(self) -> Optional[ModelRepoVariant]:
if self.format_type == ModelFormat.Checkpoint:
if self.layout == FSLayout.FILE:
return None
weight_files = list(self.path.glob("**/*.safetensors"))
@@ -581,7 +585,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.format_type == ModelFormat.Checkpoint:
if mod.layout == FSLayout.FILE:
return False
config_path = mod.path / "config.json"