mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 06:07:55 -05:00
Deprecate checkpoint as file, diffusers as directory terminology
This commit is contained in:
@@ -60,13 +60,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognize this combination of model type and format."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
|
||||
|
||||
|
||||
class FSLayout(Enum):
|
||||
FILE = "file"
|
||||
DIRECTORY = "directory"
|
||||
|
||||
|
||||
class SubmodelDefinition(BaseModel):
|
||||
path_or_prefix: str
|
||||
model_type: ModelType
|
||||
@@ -103,7 +107,7 @@ class ModelOnDisk:
|
||||
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
|
||||
self.path = path
|
||||
# TODO: Revisit checkpoint vs diffusers terminology
|
||||
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
|
||||
self.layout = FSLayout.DIRECTORY if path.is_dir() else FSLayout.FILE
|
||||
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
|
||||
self.name = path.stem
|
||||
else:
|
||||
@@ -115,18 +119,18 @@ class ModelOnDisk:
|
||||
return ModelHash(algorithm=self.hash_algo).hash(self.path)
|
||||
|
||||
def size(self) -> int:
|
||||
if self.format_type == ModelFormat.Checkpoint:
|
||||
if self.layout == FSLayout.FILE:
|
||||
return self.path.stat().st_size
|
||||
return sum(file.stat().st_size for file in self.path.rglob("*"))
|
||||
|
||||
def component_paths(self) -> set[Path]:
|
||||
if self.format_type == ModelFormat.Checkpoint:
|
||||
if self.layout == FSLayout.FILE:
|
||||
return {self.path}
|
||||
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
|
||||
return {f for f in self.path.rglob("*") if f.suffix in extensions}
|
||||
|
||||
def repo_variant(self) -> Optional[ModelRepoVariant]:
|
||||
if self.format_type == ModelFormat.Checkpoint:
|
||||
if self.layout == FSLayout.FILE:
|
||||
return None
|
||||
|
||||
weight_files = list(self.path.glob("**/*.safetensors"))
|
||||
@@ -581,7 +585,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
if mod.format_type == ModelFormat.Checkpoint:
|
||||
if mod.layout == FSLayout.FILE:
|
||||
return False
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
|
||||
Reference in New Issue
Block a user