mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Compare commits
6 Commits
controlnet
...
allow-embe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
653a8f67a4 | ||
|
|
90f32d17ba | ||
|
|
a7b8bbc7c6 | ||
|
|
befcf779dc | ||
|
|
18098cc4b9 | ||
|
|
bbecb99eb4 |
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
@@ -145,7 +146,8 @@ class ModelProbe(object):
|
||||
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||
|
||||
probe = probe_class(model_path)
|
||||
|
||||
model_path = probe.model_path
|
||||
format_type = probe.get_format()
|
||||
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
|
||||
fields["source"] = fields.get("source") or model_path.as_posix()
|
||||
fields["key"] = fields.get("key", uuid_string())
|
||||
@@ -159,7 +161,8 @@ class ModelProbe(object):
|
||||
fields["description"] = (
|
||||
fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["format"] = fields.get("format") or format_type
|
||||
|
||||
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
||||
|
||||
fields["default_settings"] = fields.get("default_settings")
|
||||
@@ -643,17 +646,19 @@ class VaeFolderProbe(FolderProbeBase):
|
||||
return name
|
||||
|
||||
|
||||
class TextualInversionFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> ModelFormat:
|
||||
return ModelFormat.EmbeddingFolder
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
path = self.model_path / "learned_embeds.bin"
|
||||
if not path.exists():
|
||||
class TextualInversionFolderProbe(TextualInversionCheckpointProbe):
|
||||
def __init__(self, model_path: Path):
|
||||
files = os.scandir(model_path)
|
||||
files = [
|
||||
Path(f.path)
|
||||
for f in files
|
||||
if f.is_file() and f.name.endswith((".ckpt", ".pt", ".pth", ".bin", ".safetensors"))
|
||||
]
|
||||
if len(files) != 1:
|
||||
raise InvalidModelConfigException(
|
||||
f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file"
|
||||
f"Unable to determine base type for {model_path}: expected exactly one valid model file, found {[f.name for f in files]}."
|
||||
)
|
||||
return TextualInversionCheckpointProbe(path).get_base_type()
|
||||
super().__init__(files.pop())
|
||||
|
||||
|
||||
class ONNXFolderProbe(PipelineFolderProbe):
|
||||
@@ -699,17 +704,16 @@ class ControlNetFolderProbe(FolderProbeBase):
|
||||
return base_model
|
||||
|
||||
|
||||
class LoRAFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
model_file = None
|
||||
for suffix in ["safetensors", "bin"]:
|
||||
base_file = self.model_path / f"pytorch_lora_weights.{suffix}"
|
||||
if base_file.exists():
|
||||
model_file = base_file
|
||||
break
|
||||
if not model_file:
|
||||
raise InvalidModelConfigException("Unknown LoRA format encountered")
|
||||
return LoRACheckpointProbe(model_file).get_base_type()
|
||||
class LoRAFolderProbe(LoRACheckpointProbe):
|
||||
def __init__(self, model_path: Path):
|
||||
files = os.scandir(model_path)
|
||||
files = [Path(f.path) for f in files if f.is_file() and f.name.endswith((".bin", ".safetensors"))]
|
||||
if len(files) != 1:
|
||||
raise InvalidModelConfigException(
|
||||
f"Unable to determine base type for lora {model_path}: expected exactly one valid model file, found {[f.name for f in files]}."
|
||||
)
|
||||
model_file = files.pop()
|
||||
super().__init__(model_file)
|
||||
|
||||
|
||||
class IPAdapterFolderProbe(FolderProbeBase):
|
||||
|
||||
Reference in New Issue
Block a user