Compare commits

..

28 Commits

Author SHA1 Message Date
psychedelicious
46405c5f09 fix(mm): fix update error
whoops!
2023-08-08 12:25:45 +10:00
Kevin Turner
4367061b19 fix(ModelManager): fix overridden VAE with relative path (#4059) 2023-08-07 12:57:32 -07:00
Kevin Turner
f272a44feb Merge branch 'main' into refactor/model_manager_instantiate 2023-08-07 10:59:28 -07:00
Jonathan
ae17d01e1d Fix hue adjustment (#4182)
* Fix hue adjustment

Hue adjustment wasn't working correctly because color channels got swapped. This has now been fixed and we're using PIL rather than cv2 to do the RGBA->HSV->RGBA conversion. The range of hue adjustment is also the more typical 0..360 degrees.
2023-08-06 23:23:51 +00:00
Kevin Turner
5bfd6cb66f Merge remote-tracking branch 'origin/main' into refactor/model_manager_instantiate
# Conflicts:
#	invokeai/backend/model_management/model_manager.py
2023-08-05 22:02:28 -07:00
Kevin Turner
7f4c387080 test(model_management): factor out name strings 2023-08-05 15:46:46 -07:00
Kevin Turner
80876bbbd1 Merge remote-tracking branch 'origin/refactor/model_manager_instantiate' into refactor/model_manager_instantiate 2023-08-05 15:25:05 -07:00
Kevin Turner
7a4ff4c089 Merge branch 'main' into refactor/model_manager_instantiate 2023-08-05 15:23:38 -07:00
Kevin Turner
44bf308192 test(model_management): add a couple tests for _get_model_path 2023-08-05 15:22:23 -07:00
Kevin Turner
65ed224bfc Merge branch 'main' into refactor/model_manager_instantiate 2023-08-04 21:34:38 -07:00
Kevin Turner
b10cf20eb1 Merge branch 'main' into refactor/model_manager_instantiate
# Conflicts:
#	invokeai/backend/model_management/model_manager.py
2023-08-04 18:28:18 -07:00
Kevin Turner
91ebf9f76e Merge branch 'main' into refactor/model_manager_instantiate 2023-08-02 19:01:21 -07:00
Kevin Turner
02d2cc758d Merge branch 'main' into refactor/model_manager_instantiate 2023-08-02 17:11:23 -07:00
Kevin Turner
1f9e984b0d Merge branch 'main' into refactor/model_manager_instantiate 2023-08-01 16:49:39 -07:00
Kevin Turner
5998509888 Merge branch 'main' into refactor/model_manager_instantiate 2023-08-01 11:09:43 -07:00
Kevin Turner
bacdf985f1 doc(model_manager): docstrings 2023-07-31 09:16:32 -07:00
Kevin Turner
e3519052ae Merge remote-tracking branch 'origin/main' into refactor/model_manager_instantiate 2023-07-31 08:46:09 -07:00
Kevin Turner
adfd1e52f4 refactor(model_manager): avoid copy/paste logic 2023-07-30 11:53:12 -07:00
Kevin Turner
0e48c98330 Merge remote-tracking branch 'origin/main' into refactor/model_manager_instantiate
# Conflicts:
#	invokeai/backend/model_management/model_manager.py
2023-07-30 11:33:13 -07:00
Kevin Turner
ff1c40747e lint: formatting 2023-07-29 20:02:31 -07:00
Kevin Turner
dbfd1bcb5e Merge branch 'main' into refactor/model_manager_instantiate 2023-07-29 19:53:21 -07:00
Kevin Turner
ccceb32a85 lint: formatting 2023-07-29 11:50:04 -07:00
Kevin Turner
21617e60e1 Merge remote-tracking branch 'origin/main' into refactor/model_manager_instantiate 2023-07-29 08:21:26 -07:00
Kevin Turner
86b8b69e88 internal(ModelManager): add instantiate method 2023-07-28 22:30:25 -07:00
Kevin Turner
bc9a5038fd refactor(ModelManager): factor out get_model_path 2023-07-28 22:29:36 -07:00
Kevin Turner
b163ae6a4d refactor(ModelManager): factor out get_model_config 2023-07-28 21:30:20 -07:00
Kevin Turner
dca685ac25 refactor(ModelManager): refactor rescan-on-miss to exists() method 2023-07-28 21:11:00 -07:00
Kevin Turner
e70bedba7d refactor(ModelManager): factor out _get_implementation method 2023-07-28 21:03:27 -07:00
9 changed files with 160 additions and 64 deletions

View File

@@ -661,27 +661,23 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
# Inputs
image: ImageField = Field(default=None, description="The image to adjust")
hue: int = Field(default=0, description="The degrees by which to rotate the hue")
hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
# Convert PIL image to OpenCV format (numpy array), note color channel
# ordering is changed from RGB to BGR
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
# Convert image to HSV color space
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
hsv_image = numpy.array(pil_image.convert("HSV"))
# Adjust the hue
hsv_image[:, :, 0] = (hsv_image[:, :, 0] + self.hue) % 180
# Convert hue from 0..360 to 0..256
hue = int(256 * ((self.hue % 360) / 360))
# Convert image back to BGR color space
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
# Increment each hue and wrap around at 255
hsv_image[:, :, 0] = (hsv_image[:, :, 0] + hue) % 256
# Convert back to PIL format and to original color mode
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA")
image_dto = context.services.images.create(
image=pil_image,

View File

@@ -180,10 +180,6 @@ class TextToLatentsInvocation(BaseInvocation):
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
# for ancestral and sde schedulers
generator = torch.Generator(device=unet.device)
generator.seed()
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
text_embeddings=c,
@@ -202,7 +198,7 @@ class TextToLatentsInvocation(BaseInvocation):
# for ddim scheduler
eta=0.0, # ddim_eta
# for ancestral and sde schedulers
generator=generator,
generator=torch.Generator(device=unet.device).manual_seed(0),
)
return conditioning_data

View File

@@ -228,19 +228,19 @@ the root is the InvokeAI ROOTDIR.
"""
from __future__ import annotations
import os
import hashlib
import os
import textwrap
import yaml
import types
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, List, Tuple, Union, Dict, Set, Callable, types
from shutil import rmtree, move
from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
import torch
import yaml
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pydantic import BaseModel, Field
import invokeai.backend.util.logging as logger
@@ -259,6 +259,7 @@ from .models import (
ModelNotFoundException,
InvalidModelException,
DuplicateModelException,
ModelBase,
)
# We are only starting to number the config file with release 3.
@@ -361,7 +362,7 @@ class ModelManager(object):
if model_key.startswith("_"):
continue
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
@@ -381,18 +382,24 @@ class ModelManager(object):
# causing otherwise unreferenced models to be removed from memory
self._read_models()
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
"""
Given a model name, returns True if it is a valid
identifier.
Given a model name, returns True if it is a valid identifier.
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param rescan: if True, scan_models_directory
"""
model_key = self.create_key(model_name, base_model, model_type)
return model_key in self.models
exists = model_key in self.models
# if model not found try to find it (maybe file just pasted)
if rescan and not exists:
self.scan_models_directory(base_model=base_model, model_type=model_type)
exists = self.model_exists(model_name, base_model, model_type, rescan=False)
return exists
@classmethod
def create_key(
@@ -443,39 +450,32 @@ class ModelManager(object):
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param submode_typel: an ModelType enum indicating the portion of
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
model_class = MODEL_CLASSES[base_model][model_type]
model_key = self.create_key(model_name, base_model, model_type)
# if model not found try to find it (maybe file just pasted)
if model_key not in self.models:
self.scan_models_directory(base_model=base_model, model_type=model_type)
if model_key not in self.models:
raise ModelNotFoundException(f"Model not found - {model_key}")
if not self.model_exists(model_name, base_model, model_type, rescan=True):
raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self.models[model_key]
model_path = self.resolve_model_path(model_config.path)
model_config = self._get_model_config(base_model, model_name, model_type)
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
if is_submodel_override:
model_type = submodel_type
submodel_type = None
model_class = self._get_implementation(base_model, model_type)
if not model_path.exists():
if model_class.save_to_config:
self.models[model_key].error = ModelError.NotFound
raise Exception(f'Files for model "{model_key}" not found')
raise Exception(f'Files for model "{model_key}" not found at {model_path}')
else:
self.models.pop(model_key, None)
raise ModelNotFoundException(f"Model not found - {model_key}")
# vae/movq override
# TODO:
if submodel_type is not None and hasattr(model_config, submodel_type):
override_path = getattr(model_config, submodel_type)
if override_path:
model_path = self.resolve_path(override_path)
model_type = submodel_type
submodel_type = None
model_class = MODEL_CLASSES[base_model][model_type]
raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}')
# TODO: path
# TODO: is it accurate to use path as id
@@ -513,6 +513,55 @@ class ModelManager(object):
_cache=self.cache,
)
def _get_model_path(
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
) -> (Path, bool):
"""Extract a model's filesystem path from its config.
:return: The fully qualified Path of the module (or submodule).
"""
model_path = model_config.path
is_submodel_override = False
# Does the config explicitly override the submodel?
if submodel_type is not None and hasattr(model_config, submodel_type):
submodel_path = getattr(model_config, submodel_type)
if submodel_path is not None:
model_path = getattr(model_config, submodel_type)
is_submodel_override = True
model_path = self.resolve_model_path(model_path)
return model_path, is_submodel_override
def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
"""Get a model's config object."""
model_key = self.create_key(model_name, base_model, model_type)
try:
model_config = self.models[model_key]
except KeyError:
raise ModelNotFoundException(f"Model not found - {model_key}")
return model_config
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
"""Get the concrete implementation class for a specific model type."""
model_class = MODEL_CLASSES[base_model][model_type]
return model_class
def _instantiate(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
) -> ModelBase:
"""Make a new instance of this model, without loading it."""
model_config = self._get_model_config(base_model, model_name, model_type)
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
# FIXME: do non-overriden submodels get the right class?
constructor = self._get_implementation(base_model, model_type)
instance = constructor(model_path, base_model, model_type)
return instance
def model_info(
self,
model_name: str,
@@ -546,7 +595,7 @@ class ModelManager(object):
the combined format of the list_models() method.
"""
models = self.list_models(base_model, model_type, model_name)
if len(models) > 1:
if len(models) > 0:
return models[0]
return None
@@ -660,7 +709,7 @@ class ModelManager(object):
if path := model_attributes.get("path"):
model_attributes["path"] = str(self.relative_model_path(Path(path)))
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
model_config = model_class.create_config(**model_attributes)
model_key = self.create_key(model_name, base_model, model_type)
@@ -851,7 +900,7 @@ class ModelManager(object):
for model_key, model_config in self.models.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
model_class = self._get_implementation(base_model, model_type)
if model_class.save_to_config:
# TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
@@ -909,7 +958,7 @@ class ModelManager(object):
model_path = self.resolve_model_path(model_config.path).absolute()
if not model_path.exists():
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
model_class = self._get_implementation(cur_base_model, cur_model_type)
if model_class.save_to_config:
model_config.error = ModelError.NotFound
self.models.pop(model_key, None)
@@ -925,7 +974,7 @@ class ModelManager(object):
for cur_model_type in ModelType:
if model_type is not None and cur_model_type != model_type:
continue
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
model_class = self._get_implementation(cur_base_model, cur_model_type)
models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
if not models_dir.exists():

View File

@@ -1,9 +1,14 @@
import os
import torch
import safetensors
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from typing import Optional
import safetensors
import torch
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
from invokeai.app.services.config import InvokeAIAppConfig
from .base import (
ModelBase,
ModelConfigBase,
@@ -18,9 +23,6 @@ from .base import (
InvalidModelException,
ModelNotFoundException,
)
from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
class VaeModelFormat(str, Enum):
@@ -80,7 +82,7 @@ class VaeModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException()
raise ModelNotFoundException(f"Does not exist as local file: {path}")
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):

View File

@@ -100,7 +100,7 @@ dependencies = [
"dev" = [
"pudb",
]
"test" = ["pytest>6.0.0", "pytest-cov", "black"]
"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"]
"xformers" = [
"xformers~=0.0.19; sys_platform!='darwin'",
"triton; sys_platform=='linux'",

View File

@@ -0,0 +1,38 @@
from pathlib import Path
import pytest
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType
BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
@pytest.fixture
def model_manager(datadir) -> ModelManager:
InvokeAIAppConfig.get_config(root=datadir)
return ModelManager(datadir / "configs" / "relative_sub.models.yaml")
def test_get_model_names(model_manager: ModelManager):
names = model_manager.model_names()
assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME]
def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2])
top_model_path, is_override = model_manager._get_model_path(model_config)
expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0"
assert top_model_path == expected_model_path
assert not is_override
def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(
VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2]
)
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
assert vae_model_path == expected_vae_path
assert is_override

View File

@@ -0,0 +1,15 @@
__metadata__:
version: 3.0.0
sdxl/main/SDXL base:
path: sdxl/main/SDXL base 1_0
description: SDXL base v1.0
variant: normal
format: diffusers
sdxl/main/SDXL with VAE:
path: sdxl/main/SDXL base 1_0
description: SDXL with customized VAE
vae: sdxl/vae/sdxl-vae-fp16-fix/
variant: normal
format: diffusers