Compare commits

...

6 Commits

Author           SHA1        Message                                                          Date
brandonrising    653a8f67a4  Run ruff                                                         2024-03-26 15:10:28 -04:00
brandonrising    90f32d17ba  Use correct format type on model                                 2024-03-26 15:10:28 -04:00
Brandon Rising   a7b8bbc7c6  Use model_path defined in probe run                              2024-03-26 15:10:28 -04:00
Brandon Rising   befcf779dc  Fix paths in probe                                               2024-03-26 15:10:28 -04:00
Brandon Rising   18098cc4b9  For loras/embeddings treat folders the same as single files     2024-03-26 15:10:28 -04:00
Brandon Rising   bbecb99eb4  Allow Embedding and Lora Folder Models to have Different Names  2024-03-26 15:10:28 -04:00
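
Taken together, these commits let embedding and LoRA folder models carry arbitrary file names: each folder probe now scans its directory for exactly one weights file and hands that file to the matching checkpoint probe. A minimal sketch of the shared pattern; the helper name below is illustrative, not part of the InvokeAI API:

    import os
    from pathlib import Path

    WEIGHT_SUFFIXES = (".ckpt", ".pt", ".pth", ".bin", ".safetensors")

    def sole_weights_file(model_path: Path) -> Path:
        # Gather every plausible weights file sitting directly inside the folder.
        files = [Path(f.path) for f in os.scandir(model_path) if f.is_file() and f.name.endswith(WEIGHT_SUFFIXES)]
        # Zero or several candidates would make probing ambiguous, so insist on exactly one.
        if len(files) != 1:
            raise ValueError(f"expected exactly one weights file in {model_path}, found {[f.name for f in files]}")
        return files[0]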


@@ -1,4 +1,5 @@
 import json
+import os
 import re
 from pathlib import Path
 from typing import Any, Dict, Literal, Optional, Union
@@ -145,7 +146,8 @@ class ModelProbe(object):
             raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
 
         probe = probe_class(model_path)
-
+        model_path = probe.model_path
+        format_type = probe.get_format()
         fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
         fields["source"] = fields.get("source") or model_path.as_posix()
         fields["key"] = fields.get("key", uuid_string())
@@ -159,7 +161,8 @@ class ModelProbe(object):
         fields["description"] = (
             fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
         )
-        fields["format"] = fields.get("format") or probe.get_format()
+        fields["format"] = fields.get("format") or format_type
         fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
 
         fields["default_settings"] = fields.get("default_settings")
@@ -643,17 +646,19 @@ class VaeFolderProbe(FolderProbeBase):
         return name
 
 
-class TextualInversionFolderProbe(FolderProbeBase):
-    def get_format(self) -> ModelFormat:
-        return ModelFormat.EmbeddingFolder
-
-    def get_base_type(self) -> BaseModelType:
-        path = self.model_path / "learned_embeds.bin"
-        if not path.exists():
+class TextualInversionFolderProbe(TextualInversionCheckpointProbe):
+    def __init__(self, model_path: Path):
+        files = os.scandir(model_path)
+        files = [
+            Path(f.path)
+            for f in files
+            if f.is_file() and f.name.endswith((".ckpt", ".pt", ".pth", ".bin", ".safetensors"))
+        ]
+        if len(files) != 1:
             raise InvalidModelConfigException(
-                f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file"
+                f"Unable to determine base type for {model_path}: expected exactly one valid model file, found {[f.name for f in files]}."
             )
-        return TextualInversionCheckpointProbe(path).get_base_type()
+        super().__init__(files.pop())
 
 
 class ONNXFolderProbe(PipelineFolderProbe):
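
With this change an embedding folder no longer has to contain a file literally named learned_embeds.bin: any single .ckpt/.pt/.pth/.bin/.safetensors file is accepted, and base type resolution is inherited from TextualInversionCheckpointProbe. A hypothetical call, with an invented folder path for illustration:

    # embeddings/my-style/ holding a lone my-style.safetensors now probes cleanly,
    # where previously only a learned_embeds.bin inside it would have worked.
    probe = TextualInversionFolderProbe(Path("embeddings/my-style"))
    base = probe.get_base_type()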
@@ -699,17 +704,16 @@ class ControlNetFolderProbe(FolderProbeBase):
         return base_model
 
 
-class LoRAFolderProbe(FolderProbeBase):
-    def get_base_type(self) -> BaseModelType:
-        model_file = None
-        for suffix in ["safetensors", "bin"]:
-            base_file = self.model_path / f"pytorch_lora_weights.{suffix}"
-            if base_file.exists():
-                model_file = base_file
-                break
-        if not model_file:
-            raise InvalidModelConfigException("Unknown LoRA format encountered")
-        return LoRACheckpointProbe(model_file).get_base_type()
+class LoRAFolderProbe(LoRACheckpointProbe):
+    def __init__(self, model_path: Path):
+        files = os.scandir(model_path)
+        files = [Path(f.path) for f in files if f.is_file() and f.name.endswith((".bin", ".safetensors"))]
+        if len(files) != 1:
+            raise InvalidModelConfigException(
+                f"Unable to determine base type for lora {model_path}: expected exactly one valid model file, found {[f.name for f in files]}."
+            )
+        model_file = files.pop()
+        super().__init__(model_file)
 
 
 class IPAdapterFolderProbe(FolderProbeBase):
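
LoRAFolderProbe gets the same treatment: instead of requiring pytorch_lora_weights.safetensors or pytorch_lora_weights.bin, it now accepts any single .bin or .safetensors file in the folder and raises InvalidModelConfigException otherwise. A hypothetical usage, with an invented path for illustration:

    # loras/detail-tweaker/ contains exactly one weights file, e.g. detail-tweaker-xl.safetensors.
    probe = LoRAFolderProbe(Path("loras/detail-tweaker"))
    base = probe.get_base_type()  # inherited from LoRACheckpointProbe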