mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Merge branch 'main' into feat/select-vram-in-config
This commit is contained in:
@@ -101,9 +101,9 @@ class ModelInstall(object):
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
||||
model_manager: ModelManager = None,
|
||||
access_token: str = None,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
model_manager: Optional[ModelManager] = None,
|
||||
access_token: Optional[str] = None,
|
||||
):
|
||||
self.config = config
|
||||
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
||||
|
||||
@@ -234,7 +234,7 @@ import textwrap
|
||||
import yaml
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||
from typing import Literal, Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||
from shutil import rmtree, move
|
||||
|
||||
import torch
|
||||
@@ -518,7 +518,7 @@ class ModelManager(object):
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> dict:
|
||||
) -> Union[dict, None]:
|
||||
"""
|
||||
Given a model name returns the OmegaConf (dict-like) object describing it.
|
||||
"""
|
||||
@@ -540,13 +540,15 @@ class ModelManager(object):
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> dict:
|
||||
) -> Union[dict, None]:
|
||||
"""
|
||||
Returns a dict describing one installed model, using
|
||||
the combined format of the list_models() method.
|
||||
"""
|
||||
models = self.list_models(base_model, model_type, model_name)
|
||||
return models[0] if models else None
|
||||
if len(models) > 1:
|
||||
return models[0]
|
||||
return None
|
||||
|
||||
def list_models(
|
||||
self,
|
||||
@@ -560,7 +562,7 @@ class ModelManager(object):
|
||||
|
||||
model_keys = (
|
||||
[self.create_key(model_name, base_model, model_type)]
|
||||
if model_name
|
||||
if model_name and base_model and model_type
|
||||
else sorted(self.models, key=str.casefold)
|
||||
)
|
||||
models = []
|
||||
@@ -596,7 +598,7 @@ class ModelManager(object):
|
||||
Print a table of models and their descriptions. This needs to be redone
|
||||
"""
|
||||
# TODO: redo
|
||||
for model_type, model_dict in self.list_models().items():
|
||||
for model_dict in self.list_models():
|
||||
for model_name, model_info in model_dict.items():
|
||||
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
|
||||
print(line)
|
||||
@@ -699,8 +701,8 @@ class ModelManager(object):
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str = None,
|
||||
new_base: BaseModelType = None,
|
||||
new_name: Optional[str] = None,
|
||||
new_base: Optional[BaseModelType] = None,
|
||||
):
|
||||
"""
|
||||
Rename or rebase a model.
|
||||
@@ -753,7 +755,7 @@ class ModelManager(object):
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Union[ModelType.Main, ModelType.Vae],
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
dest_directory: Optional[Path] = None,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
@@ -767,6 +769,10 @@ class ModelManager(object):
|
||||
This will raise a ValueError unless the model is a checkpoint.
|
||||
"""
|
||||
info = self.model_info(model_name, base_model, model_type)
|
||||
|
||||
if info is None:
|
||||
raise FileNotFoundError(f"model not found: {model_name}")
|
||||
|
||||
if info["model_format"] != "checkpoint":
|
||||
raise ValueError(f"not a checkpoint format model: {model_name}")
|
||||
|
||||
@@ -836,7 +842,7 @@ class ModelManager(object):
|
||||
|
||||
return search_folder, found_models
|
||||
|
||||
def commit(self, conf_file: Path = None) -> None:
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
"""
|
||||
@@ -983,7 +989,7 @@ class ModelManager(object):
|
||||
# LS: hacky
|
||||
# Patch in the SD VAE from core so that it is available for use by the UI
|
||||
try:
|
||||
self.heuristic_import({self.resolve_model_path("core/convert/sd-vae-ft-mse")})
|
||||
self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))})
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -1011,7 +1017,7 @@ class ModelManager(object):
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> Dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
|
||||
@@ -33,7 +33,7 @@ class ModelMerger(object):
|
||||
self,
|
||||
model_paths: List[Path],
|
||||
alpha: float = 0.5,
|
||||
interp: MergeInterpolationMethod = None,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: bool = False,
|
||||
**kwargs,
|
||||
) -> DiffusionPipeline:
|
||||
@@ -73,7 +73,7 @@ class ModelMerger(object):
|
||||
base_model: Union[BaseModelType, str],
|
||||
merged_model_name: str,
|
||||
alpha: float = 0.5,
|
||||
interp: MergeInterpolationMethod = None,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: bool = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
**kwargs,
|
||||
@@ -122,7 +122,7 @@ class ModelMerger(object):
|
||||
dump_path.mkdir(parents=True, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
attributes = dict(
|
||||
path=str(dump_path),
|
||||
description=f"Merge of models {', '.join(model_names)}",
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from packaging import version
|
||||
import platform
|
||||
|
||||
import torch
|
||||
from torch import autocast
|
||||
@@ -30,7 +32,7 @@ def choose_precision(device: torch.device) -> str:
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
|
||||
return "float16"
|
||||
elif device.type == "mps":
|
||||
elif device.type == "mps" and version.parse(platform.mac_ver()[0]) < version.parse("14.0.0"):
|
||||
return "float16"
|
||||
return "float32"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user