Compare commits

..

3 Commits

94 changed files with 874 additions and 1305 deletions

View File

@@ -18,22 +18,6 @@ Note that any releases marked as _pre-release_ are in a beta state. You may expe
The Model Manager tab in the UI provides a few ways to install models, including using your already-downloaded models. You'll see a popup directing you there on first startup. For more information, see the [model install docs].
## Missing models after updating to v4
If you find some models are missing after updating to v4, it's likely they weren't correctly registered before the update and didn't get picked up in the migration.
You can use the `Scan Folder` tab in the Model Manager UI to fix this. The models will either be in the old, now-unused `autoimport` folder, or your `models` folder.
- Find and copy your install's old `autoimport` folder path, install the main install folder.
- Go to the Model Manager and click `Scan Folder`.
- Paste the path and scan.
- IMPORTANT: Uncheck `Inplace install`.
- Click `Install All` to install all found models, or just install the models you want.
Next, find and copy your install's `models` folder path (this could be your custom models folder path, or the `models` folder inside the main install folder).
Follow the same steps to scan and import the missing models.
## Slow generation
- Check the [system requirements] to ensure that your system is capable of generating images.

View File

@@ -44,7 +44,7 @@ The installation process is simple, with a few prompts:
- Select the version to install. Unless you have a specific reason to install a specific version, select the default (the latest version).
- Select location for the install. Be sure you have enough space in this folder for the base application, as described in the [installation requirements].
- Select a GPU device.
- Select a GPU device. If you are unsure, you can let the installer figure it out.
!!! info "Slow Installation"

View File

@@ -6,7 +6,11 @@
## Introduction
InvokeAI is distributed as a python package on PyPI, installable with `pip`. There are a few things that are handled by the installer and launcher that you'll need to manage manually, described in this guide.
!!! tip "Conda"
As of InvokeAI v2.3.0 installation using the `conda` package manager is no longer being supported. It will likely still work, but we are not testing this installation method.
InvokeAI is distributed as a python package on PyPI, installable with `pip`. There are a few things that are handled by the installer that you'll need to manage manually, described in this guide.
### Requirements
@@ -36,11 +40,11 @@ Before you start, go through the [installation requirements].
1. Enter the root (invokeai) directory and create a virtual Python environment within it named `.venv`.
!!! warning "Virtual Environment Location"
!!! info "Virtual Environment Location"
While you may create the virtual environment anywhere in the file system, we recommend that you create it within the root directory as shown here. This allows the application to automatically detect its data directories.
If you choose a different location for the venv, then you _must_ set the `INVOKEAI_ROOT` environment variable or specify the root directory using the `--root` CLI arg.
If you choose a different location for the venv, then you must set the `INVOKEAI_ROOT` environment variable or pass the directory using the `--root` CLI arg.
```terminal
cd $INVOKEAI_ROOT
@@ -77,23 +81,31 @@ Before you start, go through the [installation requirements].
python3 -m pip install --upgrade pip
```
1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features.
1. Install the InvokeAI Package. The `--extra-index-url` option is used to select the correct `torch` backend:
- You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website]. Copy the `--extra-index-url` string from this and append it to your install command.
=== "CUDA (NVidia)"
!!! example "Install with an extra index URL"
```bash
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```
```bash
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
```
=== "ROCm (AMD)"
- If you have a CUDA GPU and want to install with `xformers`, you need to add an option to the package name. Note that `xformers` is not necessary. PyTorch includes an implementation of the SDP attention algorithm with the same performance.
```bash
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.6
```
!!! example "Install with `xformers`"
=== "CPU (Intel Macs & non-GPU systems)"
```bash
pip install "InvokeAI[xformers]" --use-pep517
```
```bash
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
```
=== "MPS (Apple Silicon)"
```bash
pip install InvokeAI --use-pep517
```
1. Deactivate and reactivate your runtime directory so that the invokeai-specific commands become available in the environment:
@@ -114,6 +126,37 @@ Before you start, go through the [installation requirements].
Run `invokeai-web` to start the UI. You must activate the virtual environment before running the app.
!!! warning
If the virtual environment you selected is NOT inside `INVOKEAI_ROOT`, then you must specify the path to the root directory by adding
`--root_dir \path\to\invokeai`.
If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root_dir \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.
!!! tip
You can permanently set the location of the runtime directory
by setting the environment variable `INVOKEAI_ROOT` to the
path of the directory. As mentioned previously, this is
recommended if your virtual environment is located outside of
your runtime directory.
## Unsupported Conda Install
Congratulations, you found the "secret" Conda installation instructions. If you really **really** want to use Conda with InvokeAI, you can do so using this unsupported recipe:
```sh
mkdir ~/invokeai
conda create -n invokeai python=3.11
conda activate invokeai
# Adjust this as described above for the appropriate torch backend
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu121
invokeai-web --root ~/invokeai
```
The `pip install` command shown in this recipe is for Linux/Windows
systems with an NVIDIA GPU. See step (6) above for the command to use
with other platforms/GPU combinations. If you don't wish to pass the
`--root` argument to `invokeai` with each launch, you may set the
environment variable `INVOKEAI_ROOT` to point to the installation directory.
Note that if you run into problems with the Conda installation, the InvokeAI
staff will **not** be able to help you out. Caveat Emptor!
[installation requirements]: INSTALL_REQUIREMENTS.md

View File

@@ -32,5 +32,5 @@ As described in the [frontend dev toolchain] docs, you can run the UI using a de
[Fork and clone]: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/fork-a-repo
[InvokeAI repo]: https://github.com/invoke-ai/InvokeAI
[frontend dev toolchain]: ../contributing/frontend/OVERVIEW.md
[manual installation]: ./020_INSTALL_MANUAL.md
[manual installation]: installation/020_INSTALL_MANUAL.md
[editable install]: https://pip.pypa.io/en/latest/cli/pip_install/#cmdoption-e

View File

@@ -3,7 +3,6 @@
InvokeAI installer script
"""
import locale
import os
import platform
import re
@@ -317,9 +316,7 @@ def upgrade_pip(venv_path: Path) -> str | None:
python = str(venv_path.expanduser().resolve() / python)
try:
result = subprocess.check_output([python, "-m", "pip", "install", "--upgrade", "pip"]).decode(
encoding=locale.getpreferredencoding()
)
result = subprocess.check_output([python, "-m", "pip", "install", "--upgrade", "pip"]).decode()
except subprocess.CalledProcessError as e:
print(e)
result = None
@@ -407,29 +404,22 @@ def get_torch_source() -> Tuple[str | None, str | None]:
# device can be one of: "cuda", "rocm", "cpu", "cuda_and_dml, autodetect"
device = select_gpu()
# The correct extra index URLs for torch are inconsistent, see https://pytorch.org/get-started/locally/#start-locally
url = None
optional_modules: str | None = None
optional_modules = "[onnx]"
if OS == "Linux":
if device.value == "rocm":
url = "https://download.pytorch.org/whl/rocm5.6"
elif device.value == "cpu":
url = "https://download.pytorch.org/whl/cpu"
elif device.value == "cuda":
# CUDA uses the default PyPi index
optional_modules = "[xformers,onnx-cuda]"
elif OS == "Windows":
if device.value == "cuda":
url = "https://download.pytorch.org/whl/cu121"
optional_modules = "[xformers,onnx-cuda]"
elif device.value == "cpu":
# CPU uses the default PyPi index, no optional modules
pass
elif OS == "Darwin":
# macOS uses the default PyPi index, no optional modules
pass
if device.value == "cuda_and_dml":
url = "https://download.pytorch.org/whl/cu121"
optional_modules = "[xformers,onnx-directml]"
# Fall back to defaults
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
return (url, optional_modules)

View File

@@ -207,8 +207,10 @@ def dest_path(dest: Optional[str | Path] = None) -> Path | None:
class GpuType(Enum):
CUDA = "cuda"
CUDA_AND_DML = "cuda_and_dml"
ROCM = "rocm"
CPU = "cpu"
AUTODETECT = "autodetect"
def select_gpu() -> GpuType:
@@ -224,6 +226,10 @@ def select_gpu() -> GpuType:
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
GpuType.CUDA,
)
nvidia_with_dml = (
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
GpuType.CUDA_AND_DML,
)
amd = (
"an [gold1 b]AMD[/] GPU (using ROCm™)",
GpuType.ROCM,
@@ -232,19 +238,27 @@ def select_gpu() -> GpuType:
"Do not install any GPU support, use CPU for generation (slow)",
GpuType.CPU,
)
autodetect = (
"I'm not sure what to choose",
GpuType.AUTODETECT,
)
options = []
if OS == "Windows":
options = [nvidia, cpu]
options = [nvidia, nvidia_with_dml, cpu]
if OS == "Linux":
options = [nvidia, amd, cpu]
elif OS == "Darwin":
options = [cpu]
# future CoreML?
if len(options) == 1:
print(f'Your platform [gold1]{OS}-{ARCH}[/] only supports the "{options[0][1]}" driver. Proceeding with that.')
return options[0][1]
# "I don't know" is always added the last option
options.append(autodetect) # type: ignore
options = {str(i): opt for i, opt in enumerate(options, 1)}
console.rule(":space_invader: GPU (Graphics Card) selection :space_invader:")
@@ -278,6 +292,11 @@ def select_gpu() -> GpuType:
),
)
if options[choice][1] is GpuType.AUTODETECT:
console.print(
"No problem. We will install CUDA support first :crossed_fingers: If Invoke does not detect a GPU, please re-run the installer and select one of the other GPU types."
)
return options[choice][1]

View File

@@ -219,13 +219,28 @@ async def scan_for_models(
non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
resolved_installed_model_paths: list[str] = []
installed_model_sources: list[str] = []
# This call lists all installed models.
for model in installed_models:
path = pathlib.Path(model.path)
# If the model has a source, we need to add it to the list of installed sources.
if model.source:
installed_model_sources.append(model.source)
# If the path is not absolute, that means it is in the app models directory, and we need to join it with
# the models path before resolving.
if not path.is_absolute():
resolved_installed_model_paths.append(str(pathlib.Path(models_path, path).resolve()))
continue
resolved_installed_model_paths.append(str(path.resolve()))
scan_results: list[FoundModel] = []
# Check if the model is installed by comparing paths, appending to the scan result.
# Check if the model is installed by comparing the resolved paths, appending to the scan result.
for p in non_core_model_paths:
path = str(p)
is_installed = any(str(models_path / m.path) == path for m in installed_models)
is_installed = path in resolved_installed_model_paths or path in installed_model_sources
found_model = FoundModel(path=path, is_installed=is_installed)
scan_results.append(found_model)
except Exception as e:
@@ -599,8 +614,8 @@ async def convert_model(
The return value is the model configuration for the converted model.
"""
model_manager = ApiDependencies.invoker.services.model_manager
loader = model_manager.load
logger = ApiDependencies.invoker.services.logger
loader = ApiDependencies.invoker.services.model_manager.load
store = ApiDependencies.invoker.services.model_manager.store
installer = ApiDependencies.invoker.services.model_manager.install
@@ -615,13 +630,7 @@ async def convert_model(
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
# loading the model will convert it into a cached diffusers file
try:
cc_size = loader.convert_cache.max_size
if cc_size == 0: # temporary set the convert cache to a positive number so that cached model is written
loader._convert_cache.max_size = 1.0
loader.load_model(model_config, submodel_type=SubModelType.Scheduler)
finally:
loader._convert_cache.max_size = cc_size
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)

View File

@@ -9,7 +9,8 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.lora_model_patcher import LoraModelPatcher
from invokeai.backend.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
@@ -80,7 +81,8 @@ class CompelInvocation(BaseInvocation):
),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, _lora_loader(), "text_encoder"),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
):
@@ -181,7 +183,8 @@ class SDXLPromptInvocationBase:
),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
LoraModelPatcher.apply_lora_to_text_encoder(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
):
@@ -259,15 +262,15 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
context, self.clip, self.prompt, False, "text_encoder", zero_on_empty=True
)
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_compel(
context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
context, self.clip2, self.prompt, True, "text_encoder_2", zero_on_empty=True
)
else:
c2, c2_pooled, ec2 = self.run_clip_compel(
context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
context, self.clip2, self.style, True, "text_encoder_2", zero_on_empty=True
)
original_size = (self.original_height, self.original_width)

View File

@@ -3,7 +3,6 @@ Invoke-managed custom node loader. See README.md for more information.
"""
import sys
import traceback
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
@@ -42,15 +41,11 @@ for d in Path(__file__).parent.iterdir():
logger.info(f"Loading node pack {module_name}")
try:
module = module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
module = module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
loaded_count += 1
except Exception:
full_error = traceback.format_exc()
logger.error(f"Failed to load node pack {module_name}:\n{full_error}")
loaded_count += 1
del init, module_name

View File

@@ -1,22 +1,21 @@
from builtins import float
from typing import List, Literal, Union
from typing import List, Union
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
ModelType,
)
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
class IPAdapterField(BaseModel):
@@ -49,15 +48,12 @@ class IPAdapterOutput(BaseInvocationOutput):
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
# Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).", ui_order=1)
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.",
title="IP-Adapter Model",
@@ -65,11 +61,7 @@ class IPAdapterInvocation(BaseInvocation):
ui_order=-1,
ui_type=UIType.IPAdapterModel,
)
clip_vision_model: Literal["auto", "ViT-H", "ViT-G"] = InputField(
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
default="auto",
ui_order=2,
)
weight: Union[float, List[float]] = InputField(
default=1, description="The weight given to the IP-Adapter", title="Weight"
)
@@ -94,21 +86,10 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
if self.clip_vision_model == "auto":
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
else:
raise RuntimeError(
"You need to set the appropriate CLIP Vision model for checkpoint IP Adapter models."
)
else:
image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
assert isinstance(ip_adapter_info, IPAdapterConfig)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
@@ -121,25 +102,19 @@ class IPAdapterInvocation(BaseInvocation):
)
def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
if not len(image_encoder_models) > 0:
context.logger.warning(
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed. \
Downloading and installing now. This may take a while."
)
installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
installer.wait_for_job(job, timeout=600) # Wait for up to 10 minutes
found = False
while not found:
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
if len(image_encoder_models) == 0:
context.logger.error("Error while fetching CLIP Vision Image Encoder")
assert len(image_encoder_models) == 1
found = len(image_encoder_models) > 0
if not found:
context.logger.warning(
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed."
)
context.logger.warning("Downloading and installing now. This may take a while.")
installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
installer.wait_for_job(job, timeout=600) # wait up to 10 minutes - then raise a TimeoutException
assert len(image_encoder_models) == 1
return image_encoder_models[0]

View File

@@ -43,12 +43,17 @@ from invokeai.app.invocations.fields import (
WithMetadata,
)
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput
from invokeai.app.invocations.primitives import (
DenoiseMaskOutput,
ImageOutput,
LatentsOutput,
)
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.lora_model_patcher import LoraModelPatcher
from invokeai.backend.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
@@ -64,7 +69,12 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
)
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from .controlnet_image_processors import ControlField
from .model import ModelIdentifierField, UNetField, VAEField
@@ -730,7 +740,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
unet_info as unet,
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
# ModelPatcher.apply_lora_unet(unet, _lora_loader()),
LoraModelPatcher.apply_lora_to_unet(unet, _lora_loader()),
):
assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype)

View File

@@ -2,8 +2,16 @@ from typing import Any, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import (
CONTROLNET_MODE_VALUES,
CONTROLNET_RESIZE_VALUES,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@@ -35,7 +43,6 @@ class IPAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import locale
import os
import re
import shutil
@@ -318,10 +317,11 @@ class InvokeAIAppConfig(BaseSettings):
@staticmethod
def find_root() -> Path:
"""Choose the runtime root directory when not specified on command line or init file."""
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ["INVOKEAI_ROOT"])
elif venv := os.environ.get("VIRTUAL_ENV", None):
root = Path(venv).parent.resolve()
elif any((venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]):
root = (venv.parent).resolve()
else:
root = Path("~/invokeai").expanduser().resolve()
return root
@@ -373,16 +373,13 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
else:
# Else we do not attempt to migrate this setting
# The old default for this was "configs/stable-diffusion". If if the incoming config has that as the value, we won't set it.
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
# Else we do not attempt to migrate this setting
if v != "configs/stable-diffusion":
parsed_config_dict["legacy_conf_dir"] = v
elif Path(v).name == "stable-diffusion":
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
@@ -402,7 +399,7 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
"""
assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
with open(config_path) as file:
loaded_config_dict = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict)

View File

@@ -1,6 +1,5 @@
"""Model installation class."""
import locale
import os
import re
import signal
@@ -324,8 +323,7 @@ class ModelInstallService(ModelInstallServiceBase):
legacy_models_yaml_path = Path(self._app_config.root_path, legacy_models_yaml_path)
if legacy_models_yaml_path.exists():
with open(legacy_models_yaml_path, "rt", encoding=locale.getpreferredencoding()) as file:
legacy_models_yaml = yaml.safe_load(file)
legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())
yaml_metadata = legacy_models_yaml.pop("__metadata__")
yaml_version = yaml_metadata.get("version")
@@ -350,13 +348,8 @@ class ModelInstallService(ModelInstallServiceBase):
config: dict[str, Any] = {}
config["name"] = model_name
config["description"] = stanza.get("description")
legacy_config_path = stanza.get("config")
if legacy_config_path:
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
legacy_config_path: Path = self._app_config.root_path / legacy_config_path
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
config["config_path"] = str(legacy_config_path)
config["config_path"] = stanza.get("config")
try:
id = self.register_path(model_path=model_path, config=config)
self._logger.info(f"Migrated {model_name} with id {id}")
@@ -375,13 +368,11 @@ class ModelInstallService(ModelInstallServiceBase):
def delete(self, key: str) -> None: # noqa D102
"""Unregister the model. Delete its files only if they are within our models directory."""
model = self.record_store.get_model(key)
model_path = self.app_config.models_path / model.path
if model_path.is_relative_to(self.app_config.models_path):
# If the models is in the Invoke-managed models dir, we delete it
models_dir = self.app_config.models_path
model_path = models_dir / Path(model.path) # handle legacy relative model paths
if model_path.is_relative_to(models_dir):
self.unconditionally_delete(key)
else:
# Else we only unregister it, leaving the file in place
self.unregister(key)
def unconditionally_delete(self, key: str) -> None: # noqa D102
@@ -509,9 +500,9 @@ class ModelInstallService(ModelInstallServiceBase):
def _scan_for_missing_models(self) -> list[AnyModelConfig]:
"""Scan the models directory for missing models and return a list of them."""
missing_models: list[AnyModelConfig] = []
for model_config in self.record_store.all_models():
if not (self.app_config.models_path / model_config.path).resolve().exists():
missing_models.append(model_config)
for x in self.record_store.all_models():
if not Path(x.path).resolve().exists():
missing_models.append(x)
return missing_models
def _register_orphaned_models(self) -> None:
@@ -521,9 +512,7 @@ class ModelInstallService(ModelInstallServiceBase):
only situations in which we may have orphaned models in the models directory.
"""
installed_model_paths = {
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
}
installed_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
# The bool returned by this callback determines if the model is added to the list of models found by the search
def on_model_found(model_path: Path) -> bool:
@@ -559,21 +548,20 @@ class ModelInstallService(ModelInstallServiceBase):
May raise an UnknownModelException.
"""
model = self.record_store.get_model(key)
models_dir = self.app_config.models_path
old_path = self.app_config.models_path / model.path
old_path = Path(model.path).resolve()
models_dir = self.app_config.models_path.resolve()
if not old_path.is_relative_to(models_dir):
# The model is not in the models directory - we don't need to move it.
return model
new_path = models_dir / model.base.value / model.type.value / old_path.name
new_path = (models_dir / model.base.value / model.type.value / model.name).with_suffix(old_path.suffix)
if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
return model
self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
model.path = new_path.relative_to(models_dir).as_posix()
model.path = new_path.as_posix()
self.record_store.update_model(key, ModelRecordChanges(path=model.path))
return model
@@ -612,19 +600,12 @@ class ModelInstallService(ModelInstallServiceBase):
model_path = model_path.resolve()
# Models in the Invoke-managed models dir should use relative paths.
if model_path.is_relative_to(self.app_config.models_path):
model_path = model_path.relative_to(self.app_config.models_path)
info.path = model_path.as_posix()
# Checkpoints have a config file needed for conversion - resolve this to an absolute path
if isinstance(info, CheckpointConfigBase):
# Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the
# invoke-managed legacy config dir, we use a relative path.
legacy_config_path = self.app_config.legacy_conf_path / info.config_path
if legacy_config_path.is_relative_to(self.app_config.legacy_conf_path):
legacy_config_path = legacy_config_path.relative_to(self.app_config.legacy_conf_path)
info.config_path = legacy_config_path.as_posix()
legacy_conf = (self.app_config.legacy_conf_path / info.config_path).resolve()
info.config_path = legacy_conf.as_posix()
self.record_store.add_model(info)
return info.key

View File

@@ -70,18 +70,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
async def _on_queue_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
if (
event_name == "session_canceled"
and self._queue_item
and self._queue_item.item_id == event[1]["data"]["queue_item_id"]
):
self._cancel_event.set()
self._poll_now()
elif (
event_name == "queue_cleared"
and self._queue_item
and self._queue_item.queue_id == event[1]["data"]["queue_id"]
):
if event_name == "session_canceled" or event_name == "queue_cleared":
# These both mean we should cancel the current session.
self._cancel_event.set()
self._poll_now()
elif event_name == "batch_enqueued":
@@ -121,146 +111,141 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.clear()
# Middle processor try block; any unhandled exception is a non-fatal processor error
try:
# If we are paused, wait for resume event
resume_event.wait()
# Get the next session to process
self._queue_item = self._invoker.services.session_queue.dequeue()
if self._queue_item is not None and resume_event.is_set():
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
if self._queue_item is None:
# The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
continue
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id)
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# Prepare invocations and take the first
self._invocation = self._queue_item.session.next()
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id)
# Loop over invocations until the session is complete or canceled
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
# Prepare invocations and take the first
self._invocation = self._queue_item.session.next()
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
)
# Loop over invocations until the session is complete or canceled
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
with self._invoker.services.performance_statistics.collect_stats(
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
)
# Invoke the node
outputs = self._invocation.invoke_internal(
context=context, services=self._invoker.services
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
with self._invoker.services.performance_statistics.collect_stats(
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
# Invoke the node
outputs = self._invocation.invoke_internal(
context=context, services=self._invoker.services
)
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id,
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
error_type=e.__class__.__name__,
error=error,
)
pass
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# The session is complete if the all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
# The session is complete if the all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Set the invocation to None to prepare for the next session
self._invocation = None
else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
# Set the invocation to None to prepare for the next session
self._invocation = None
else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
# The session is complete, immediately poll for next session
self._queue_item = None
poll_now_event.set()
else:
# The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")

View File

@@ -10,8 +10,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -39,8 +37,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_5())
migrator.register_migration(build_migration_6())
migrator.register_migration(build_migration_7())
migrator.register_migration(build_migration_8(app_config=config))
migrator.register_migration(build_migration_9())
migrator.run_migrations()
return db

View File

@@ -11,7 +11,7 @@ class Migration7Callback:
def _drop_old_models_tables(self, cursor: sqlite3.Cursor) -> None:
"""Drops the old model_records, model_metadata, model_tags and tags tables."""
tables = ["model_config", "model_metadata", "model_tags", "tags"]
tables = ["model_records", "model_metadata", "model_tags", "tags"]
for table in tables:
cursor.execute(f"DROP TABLE IF EXISTS {table};")

View File

@@ -1,91 +0,0 @@
import sqlite3
from pathlib import Path
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration8Callback:
def __init__(self, app_config: InvokeAIAppConfig) -> None:
self._app_config = app_config
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._drop_model_config_table(cursor)
self._migrate_abs_models_to_rel(cursor)
def _drop_model_config_table(self, cursor: sqlite3.Cursor) -> None:
"""Drops the old model_config table. This was missed in a previous migration."""
cursor.execute("DROP TABLE IF EXISTS model_config;")
def _migrate_abs_models_to_rel(self, cursor: sqlite3.Cursor) -> None:
"""Check all model paths & legacy config paths to determine if they are inside Invoke-managed directories. If
they are, update the paths to be relative to the managed directories.
This migration is a no-op for normal users (their paths will already be relative), but is necessary for users
who have been testing the RCs with their live databases. The paths were made absolute in the initial RC, but this
change was reverted. To smooth over the revert for our tests, we can migrate the paths back to relative.
"""
models_path = self._app_config.models_path
legacy_conf_path = self._app_config.legacy_conf_path
legacy_conf_dir = self._app_config.legacy_conf_dir
stmt = """---sql
SELECT
id,
path,
json_extract(config, '$.config_path') as config_path
FROM models;
"""
all_models = cursor.execute(stmt).fetchall()
for model_id, model_path, model_config_path in all_models:
# If the model path is inside the models directory, update it to be relative to the models directory.
if Path(model_path).is_relative_to(models_path):
new_path = Path(model_path).relative_to(models_path)
cursor.execute(
"""--sql
UPDATE models
SET config = json_set(config, '$.path', ?)
WHERE id = ?;
""",
(str(new_path), model_id),
)
# If the model has a legacy config path and it is inside the legacy conf directory, update it to be
# relative to the legacy conf directory. This also fixes up cases in which the config path was
# incorrectly relativized to the root directory. It will now be relativized to the legacy conf directory.
if model_config_path:
if Path(model_config_path).is_relative_to(legacy_conf_path):
new_config_path = Path(model_config_path).relative_to(legacy_conf_path)
elif Path(model_config_path).is_relative_to(legacy_conf_dir):
new_config_path = Path(*Path(model_config_path).parts[1:])
else:
new_config_path = None
if new_config_path:
cursor.execute(
"""--sql
UPDATE models
SET config = json_set(config, '$.config_path', ?)
WHERE id = ?;
""",
(str(new_config_path), model_id),
)
def build_migration_8(app_config: InvokeAIAppConfig) -> Migration:
"""
Build the migration from database version 7 to 8.
This migration does the following:
- Removes the `model_config` table.
- Migrates absolute model & legacy config paths to be relative to the models directory.
"""
migration_8 = Migration(
from_version=7,
to_version=8,
callback=Migration8Callback(app_config),
)
return migration_8

View File

@@ -1,29 +0,0 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration9Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._empty_session_queue(cursor)
def _empty_session_queue(self, cursor: sqlite3.Cursor) -> None:
"""Empties the session queue. This is done to prevent any lingering session queue items from causing pydantic errors due to changed schemas."""
cursor.execute("DELETE FROM session_queue;")
def build_migration_9() -> Migration:
"""
Build the migration from database version 8 to 9.
This migration does the following:
- Empties the session queue. This is done to prevent any lingering session queue items from causing pydantic errors due to changed schemas.
"""
migration_9 = Migration(
from_version=8,
to_version=9,
callback=Migration9Callback(),
)
return migration_9

View File

@@ -1,6 +1,4 @@
import sqlite3
from contextlib import closing
from datetime import datetime
from pathlib import Path
from typing import Optional
@@ -34,7 +32,6 @@ class SqliteMigrator:
self._db = db
self._logger = db.logger
self._migration_set = MigrationSet()
self._backup_path: Optional[Path] = None
def register_migration(self, migration: Migration) -> None:
"""Registers a migration."""
@@ -58,18 +55,6 @@ class SqliteMigrator:
return False
self._logger.info("Database update needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db.db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db.conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
while next_migration is not None:
self._run_migration(next_migration)

View File

@@ -1,32 +1,21 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
import pathlib
from typing import List, Optional, TypedDict, Union
from typing import Optional, Union
import safetensors
import safetensors.torch
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from ..raw_model import RawModel
from .resampler import Resampler
class IPAdapterStateDict(TypedDict):
ip_adapter: dict[str, torch.Tensor]
image_proj: dict[str, torch.Tensor]
class ImageProjModel(torch.nn.Module):
"""Image Projection Model"""
def __init__(
self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024, clip_extra_context_tokens: int = 4
):
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
@@ -35,7 +24,7 @@ class ImageProjModel(torch.nn.Module):
self.norm = torch.nn.LayerNorm(cross_attention_dim)
@classmethod
def from_state_dict(cls, state_dict: dict[str, torch.Tensor], clip_extra_context_tokens: int = 4):
def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
"""Initialize an ImageProjModel from a state_dict.
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -55,7 +44,7 @@ class ImageProjModel(torch.nn.Module):
model.load_state_dict(state_dict)
return model
def forward(self, image_embeds: torch.Tensor):
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
@@ -67,7 +56,7 @@ class ImageProjModel(torch.nn.Module):
class MLPProjModel(torch.nn.Module):
"""SD model with image prompt"""
def __init__(self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024):
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
super().__init__()
self.proj = torch.nn.Sequential(
@@ -78,7 +67,7 @@ class MLPProjModel(torch.nn.Module):
)
@classmethod
def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
def from_state_dict(cls, state_dict: dict[torch.Tensor]):
"""Initialize an MLPProjModel from a state_dict.
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -97,17 +86,17 @@ class MLPProjModel(torch.nn.Module):
model.load_state_dict(state_dict)
return model
def forward(self, image_embeds: torch.Tensor):
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class IPAdapter(RawModel):
class IPAdapter(torch.nn.Module):
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
def __init__(
self,
state_dict: IPAdapterStateDict,
state_dict: dict[str, torch.Tensor],
device: torch.device,
dtype: torch.dtype = torch.float16,
num_tokens: int = 4,
@@ -139,27 +128,24 @@ class IPAdapter(RawModel):
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
def _init_image_proj_model(
self, state_dict: dict[str, torch.Tensor]
) -> Union[ImageProjModel, Resampler, MLPProjModel]:
def _init_image_proj_model(self, state_dict):
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
@torch.inference_mode()
def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
try:
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
except RuntimeError as e:
raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
class IPAdapterPlus(IPAdapter):
"""IP-Adapter with fine-grained features"""
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]) -> Union[Resampler, MLPProjModel]:
def _init_image_proj_model(self, state_dict):
return Resampler.from_state_dict(
state_dict=state_dict,
depth=4,
@@ -170,32 +156,31 @@ class IPAdapterPlus(IPAdapter):
).to(self.device, dtype=self.dtype)
@torch.inference_mode()
def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=self.dtype)
clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
-2
]
try:
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
except RuntimeError as e:
raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
class IPAdapterFull(IPAdapterPlus):
"""IP-Adapter Plus with full features."""
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
def _init_image_proj_model(self, state_dict: dict[torch.Tensor]):
return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)
class IPAdapterPlusXL(IPAdapterPlus):
"""IP-Adapter Plus for SDXL."""
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
def _init_image_proj_model(self, state_dict):
return Resampler.from_state_dict(
state_dict=state_dict,
depth=4,
@@ -206,48 +191,24 @@ class IPAdapterPlusXL(IPAdapterPlus):
).to(self.device, dtype=self.dtype)
def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) -> IPAdapterStateDict:
state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
if ip_adapter_ckpt_path.suffix == ".safetensors":
model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
for key in model.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
else:
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
else:
ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path / "ip_adapter.bin"
state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
return state_dict
def build_ip_adapter(
ip_adapter_ckpt_path: pathlib.Path, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterPlus]:
state_dict = load_ip_adapter_tensors(ip_adapter_ckpt_path, device.type)
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus]:
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
# IPAdapter (with ImageProjModel)
if "proj.weight" in state_dict["image_proj"]:
if "proj.weight" in state_dict["image_proj"]: # IPAdapter (with ImageProjModel).
return IPAdapter(state_dict, device=device, dtype=dtype)
# IPAdaterPlus or IPAdapterPlusXL (with Resampler)
elif "proj_in.weight" in state_dict["image_proj"]:
elif "proj_in.weight" in state_dict["image_proj"]: # IPAdaterPlus or IPAdapterPlusXL (with Resampler).
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
return IPAdapterPlus(state_dict, device=device, dtype=dtype) # SD1 IP-Adapter Plus
# SD1 IP-Adapter Plus
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
elif cross_attention_dim == 2048:
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype) # SDXL IP-Adapter Plus
# SDXL IP-Adapter Plus
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
else:
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
# IPAdapterFull (with MLPProjModel)
elif "proj.0.weight" in state_dict["image_proj"]:
elif "proj.0.weight" in state_dict["image_proj"]: # IPAdapterFull (with MLPProjModel).
return IPAdapterFull(state_dict, device=device, dtype=dtype)
# Unrecognized IP Adapter Architectures
else:
raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")

View File

@@ -9,8 +9,8 @@ import torch.nn as nn
# FFN
def FeedForward(dim: int, mult: int = 4):
inner_dim = dim * mult
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
@@ -19,8 +19,8 @@ def FeedForward(dim: int, mult: int = 4):
)
def reshape_tensor(x: torch.Tensor, heads: int):
bs, length, _ = x.shape
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
@@ -31,7 +31,7 @@ def reshape_tensor(x: torch.Tensor, heads: int):
class PerceiverAttention(nn.Module):
def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
@@ -45,7 +45,7 @@ class PerceiverAttention(nn.Module):
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x: torch.Tensor, latents: torch.Tensor):
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
@@ -80,14 +80,14 @@ class PerceiverAttention(nn.Module):
class Resampler(nn.Module):
def __init__(
self,
dim: int = 1024,
depth: int = 8,
dim_head: int = 64,
heads: int = 16,
num_queries: int = 8,
embedding_dim: int = 768,
output_dim: int = 1024,
ff_mult: int = 4,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
):
super().__init__()
@@ -110,15 +110,7 @@ class Resampler(nn.Module):
)
@classmethod
def from_state_dict(
cls,
state_dict: dict[str, torch.Tensor],
depth: int = 8,
dim_head: int = 64,
heads: int = 16,
num_queries: int = 8,
ff_mult: int = 4,
):
def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
"""A convenience function that initializes a Resampler from a state_dict.
Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
@@ -153,7 +145,7 @@ class Resampler(nn.Module):
model.load_state_dict(state_dict)
return model
def forward(self, x: torch.Tensor):
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)

View File

@@ -0,0 +1,65 @@
from contextlib import contextmanager
from typing import Iterator, Tuple, Union
from diffusers.loaders.lora import LoraLoaderMixin
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.utils.peft_utils import recurse_remove_peft_layers
from transformers import CLIPTextModel
from invokeai.backend.lora_model_raw import LoRAModelRaw
class LoraModelPatcher:
@classmethod
def unload_lora_from_model(cls, m: Union[UNet2DConditionModel, CLIPTextModel]):
"""Unload all LoRA models from a UNet or Text Encoder.
This implementation is base on LoraLoaderMixin.unload_lora_weights().
"""
recurse_remove_peft_layers(m)
if hasattr(m, "peft_config"):
del m.peft_config # type: ignore
if hasattr(m, "_hf_peft_config_loaded"):
m._hf_peft_config_loaded = None # type: ignore
@classmethod
@contextmanager
def apply_lora_to_unet(cls, unet: UNet2DConditionModel, loras: Iterator[Tuple[LoRAModelRaw, float]]):
try:
# TODO(ryand): Test speed of low_cpu_mem_usage=True.
for lora, lora_weight in loras:
LoraLoaderMixin.load_lora_into_unet(
state_dict=lora.state_dict,
network_alphas=lora.network_alphas,
unet=unet,
low_cpu_mem_usage=True,
adapter_name=lora.name,
_pipeline=None,
)
yield
finally:
cls.unload_lora_from_model(unet)
@classmethod
@contextmanager
def apply_lora_to_text_encoder(
cls, text_encoder: CLIPTextModel, loras: Iterator[Tuple[LoRAModelRaw, float]], prefix: str
):
assert prefix in ["text_encoder", "text_encoder_2"]
try:
for lora, lora_weight in loras:
# Filter the state_dict to only include the keys that start with the prefix.
text_encoder_state_dict = {
key: value for key, value in lora.state_dict.items() if key.startswith(prefix + ".")
}
if len(text_encoder_state_dict) > 0:
LoraLoaderMixin.load_lora_into_text_encoder(
state_dict=text_encoder_state_dict,
network_alphas=lora.network_alphas,
text_encoder=text_encoder,
low_cpu_mem_usage=True,
adapter_name=lora.name,
_pipeline=None,
)
yield
finally:
cls.unload_lora_from_model(text_encoder)

View File

@@ -0,0 +1,66 @@
from pathlib import Path
from typing import Optional, Union
import torch
from diffusers.loaders.lora import LoraLoaderMixin
from typing_extensions import Self
class LoRAModelRaw:
def __init__(
self,
name: str,
state_dict: dict[str, torch.Tensor],
network_alphas: Optional[dict[str, float]],
):
self._name = name
self.state_dict = state_dict
self.network_alphas = network_alphas
@property
def name(self) -> str:
return self._name
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
for key, layer in self.state_dict.items():
self.state_dict[key] = layer.to(device=device, dtype=dtype)
def calc_size(self) -> int:
"""Calculate the size of the model in bytes."""
model_size = 0
for layer in self.state_dict.values():
model_size += layer.numel() * layer.element_size()
return model_size
@classmethod
def from_checkpoint(
cls, file_path: Union[str, Path], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> Self:
"""This function is based on diffusers LoraLoaderMixin.load_lora_weights()."""
file_path = Path(file_path)
if file_path.is_dir():
raise NotImplementedError("LoRA models from directories are not yet supported.")
dir_path = file_path.parent
file_name = file_path.name
state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
pretrained_model_name_or_path_or_dict=str(file_path), local_files_only=True, weight_name=str(file_name)
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
model = cls(
# TODO(ryand): Handle both files and directories here?
name=Path(file_path).stem,
state_dict=state_dict,
network_alphas=network_alphas,
)
device = device or torch.device("cpu")
dtype = dtype or torch.float32
model.to(device=device, dtype=dtype)
return model

View File

@@ -11,8 +11,6 @@ from typing_extensions import Self
from invokeai.backend.model_manager import BaseModelType
from .raw_model import RawModel
class LoRALayerBase:
# rank: Optional[int]
@@ -368,7 +366,7 @@ class IA3Layer(LoRALayerBase):
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
class LoRAModelRaw(RawModel): # (torch.nn.Module):
class LoRAModelRaw(torch.nn.Module):
_name: str
layers: Dict[str, AnyLoRALayer]

View File

@@ -33,3 +33,42 @@ __all__ = [
"SchedulerPredictionType",
"SubModelType",
]
########## to help populate the openapi_schema with format enums for each config ###########
# This code is no longer necessary?
# leave it here just in case
#
# import inspect
# from enum import Enum
# from typing import Any, Iterable, Dict, get_args, Set
# def _expand(something: Any) -> Iterable[type]:
# if isinstance(something, type):
# yield something
# else:
# for x in get_args(something):
# for y in _expand(x):
# yield y
# def _find_format(cls: type) -> Iterable[Enum]:
# if hasattr(inspect, "get_annotations"):
# fields = inspect.get_annotations(cls)
# else:
# fields = cls.__annotations__
# if "format" in fields:
# for x in get_args(fields["format"]):
# yield x
# for parent_class in cls.__bases__:
# for x in _find_format(parent_class):
# yield x
# return None
# def get_model_config_formats() -> Dict[str, Set[Enum]]:
# result: Dict[str, Set[Enum]] = {}
# for model_config in _expand(AnyModelConfig):
# for field in _find_format(model_config):
# if field is None:
# continue
# if not result.get(model_config.__qualname__):
# result[model_config.__qualname__] = set()
# result[model_config.__qualname__].add(field)
# return result

View File

@@ -31,12 +31,13 @@ from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string
from ..raw_model import RawModel
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora_model_raw import LoRAModelRaw
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
AnyModel = Union[ModelMixin, torch.nn.Module, IPAdapter, LoRAModelRaw, TextualInversionModelRaw, IAIOnnxRuntimeModel]
class InvalidModelConfigException(Exception):
@@ -323,13 +324,10 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
class IPAdapterBaseConfig(ModelConfigBase):
class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
"""Model config for IP Adapter diffusers format models."""
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI]
@@ -338,16 +336,6 @@ class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
"""Model config for IP Adapter checkpoint format models."""
format: Literal[ModelFormat.Checkpoint]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision."""
@@ -403,8 +391,7 @@ AnyModelConfig = Annotated[
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
],

View File

@@ -3,10 +3,10 @@
"""Conversion script for the Stable Diffusion checkpoints."""
from pathlib import Path
from typing import Optional
from typing import Dict
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
convert_ldm_vae_checkpoint,
create_vae_diffusers_config,
@@ -15,14 +15,11 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
)
from omegaconf import DictConfig
from . import AnyModel
def convert_ldm_vae_to_diffusers(
checkpoint: torch.Tensor | dict[str, torch.Tensor],
checkpoint: Dict[str, torch.Tensor],
vae_config: DictConfig,
image_size: int,
dump_path: Optional[Path] = None,
precision: torch.dtype = torch.float16,
) -> AutoencoderKL:
"""Convert a checkpoint-style VAE into a Diffusers VAE"""
@@ -31,21 +28,16 @@ def convert_ldm_vae_to_diffusers(
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
vae.to(precision)
if dump_path:
vae.save_pretrained(dump_path, safe_serialization=True)
return vae
return vae.to(precision)
def convert_ckpt_to_diffusers(
checkpoint_path: str | Path,
dump_path: Optional[str | Path] = None,
dump_path: str | Path,
precision: torch.dtype = torch.float16,
use_safetensors: bool = True,
**kwargs,
) -> AnyModel:
):
"""
Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
@@ -55,20 +47,18 @@ def convert_ckpt_to_diffusers(
pipe = pipe.to(precision)
# TO DO: save correct repo variant
if dump_path:
pipe.save_pretrained(
dump_path,
safe_serialization=use_safetensors,
)
return pipe
pipe.save_pretrained(
dump_path,
safe_serialization=use_safetensors,
)
def convert_controlnet_to_diffusers(
checkpoint_path: Path,
dump_path: Optional[Path] = None,
dump_path: Path,
precision: torch.dtype = torch.float16,
**kwargs,
) -> AnyModel:
):
"""
Takes all the arguments of download_controlnet_from_original_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
@@ -78,6 +68,4 @@ def convert_controlnet_to_diffusers(
pipe = pipe.to(precision)
# TO DO: save correct repo variant
if dump_path:
pipe.save_pretrained(dump_path, safe_serialization=True)
return pipe
pipe.save_pretrained(dump_path, safe_serialization=True)

View File

@@ -19,20 +19,11 @@ class ModelConvertCache(ModelConvertCacheBase):
self._cache_path = cache_path
self._max_size = max_size
# adjust cache size at startup in case it has been changed
if self._cache_path.exists():
self.make_room(0.0)
@property
def max_size(self) -> float:
"""Return the maximum size of this cache directory (GB)."""
return self._max_size
@max_size.setter
def max_size(self, value: float) -> None:
"""Set the maximum size of this cache directory (GB)."""
self._max_size = value
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
return self._cache_path / key

View File

@@ -83,15 +83,3 @@ class ModelLoaderBase(ABC):
) -> int:
"""Return size in bytes of the model, calculated before loading."""
pass
@property
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated with this loader."""
pass
@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the ram cache associated with this loader."""
pass

View File

@@ -3,13 +3,14 @@
from logging import Logger
from pathlib import Path
from typing import Optional
from typing import Optional, Tuple
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
InvalidModelConfigException,
ModelRepoVariant,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
@@ -53,43 +54,51 @@ class ModelLoader(ModelLoaderBase):
if model_config.type is ModelType.Main and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model")
model_path = self._get_model_path(model_config)
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
if not model_path.exists():
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
with skip_torch_weight_init():
locker = self._convert_and_load(model_config, model_path, submodel_type)
model_path = self._convert_if_needed(model_config, model_path, submodel_type)
locker = self._load_if_needed(model_config, model_path, submodel_type)
return LoadedModel(config=model_config, _locker=locker)
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated with this loader."""
return self._convert_cache
@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the ram cache associated with this loader."""
return self._ram_cache
def _get_model_path(self, config: AnyModelConfig) -> Path:
def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
model_base = self._app_config.models_path
return (model_base / config.path).resolve()
result = (model_base / config.path).resolve(), config, submodel_type
return result
def _convert_and_load(
def _convert_if_needed(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> Path:
cache_path: Path = self._convert_cache.cache_path(config.key)
if not self._needs_conversion(config, model_path, cache_path):
return cache_path if cache_path.exists() else model_path
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
return self._convert_model(config, model_path, cache_path)
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False
def _load_if_needed(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> ModelLockerBase:
# TO DO: This is not thread safe!
try:
return self._ram_cache.get(config.key, submodel_type)
except IndexError:
pass
cache_path: Path = self._convert_cache.cache_path(config.key)
if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
else:
config.path = str(cache_path) if cache_path.exists() else str(self._get_model_path(config))
loaded_model = self._load_model(config, submodel_type)
model_variant = getattr(config, "repo_variant", None)
self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
# This is where the model is actually loaded!
with skip_torch_weight_init():
loaded_model = self._load_model(model_path, model_variant=model_variant, submodel_type=submodel_type)
self._ram_cache.put(
config.key,
@@ -114,34 +123,15 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
)
def _do_convert(
self, config: AnyModelConfig, model_path: Path, cache_path: Path, submodel_type: Optional[SubModelType] = None
) -> AnyModel:
self.convert_cache.make_room(calc_model_size_by_fs(model_path))
pipeline = self._convert_model(config, model_path, cache_path if self.convert_cache.max_size > 0 else None)
if submodel_type:
# Proactively load the various submodels into the RAM cache so that we don't have to re-convert
# the entire pipeline every time a new submodel is needed.
for subtype in SubModelType:
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(
config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
)
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False
# This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
raise NotImplementedError
# This needs to be implemented in the subclass
def _load_model(
self,
config: AnyModelConfig,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
raise NotImplementedError

View File

@@ -122,11 +122,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""Return the cap on cache size."""
return self._max_cache_size
@max_cache_size.setter
def max_cache_size(self, value: float) -> None:
"""Set the cap on cache size."""
self._max_cache_size = value
@property
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
@@ -162,9 +157,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
) -> None:
"""Store model under key and optional submodel_type."""
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
self.make_room(size)
assert key not in self._cached_models
cache_record = CacheRecord(key, model, size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@@ -411,8 +405,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
#
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
# immediately when their reference count hits 0.
if self.stats:
self.stats.cleared = models_cleared
gc.collect()
torch.cuda.empty_cache()
@@ -429,8 +421,4 @@ class ModelCache(ModelCacheBase[AnyModel]):
)
free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
if needed_size > free_mem:
needed_gb = round(needed_size / GIG, 2)
free_gb = round(free_mem / GIG, 2)
raise torch.cuda.OutOfMemoryError(
f"Insufficient VRAM to load model, requested {needed_gb}GB but only had {free_gb}GB free"
)
raise torch.cuda.OutOfMemoryError

View File

@@ -2,10 +2,8 @@
"""Class for ControlNet model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
@@ -35,7 +33,7 @@ class ControlNetLoader(GenericDiffusersLoader):
else:
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
assert isinstance(config, CheckpointConfigBase)
image_size = (
512
@@ -46,8 +44,8 @@ class ControlNetLoader(GenericDiffusersLoader):
)
self._logger.info(f"Converting {model_path} to diffusers format")
with open(self._app_config.legacy_conf_path / config.config_path, "r") as config_stream:
result = convert_controlnet_to_diffusers(
with open(self._app_config.root_path / config.config_path, "r") as config_stream:
convert_controlnet_to_diffusers(
model_path,
output_path,
original_config_file=config_stream,
@@ -55,4 +53,4 @@ class ControlNetLoader(GenericDiffusersLoader):
precision=self._torch_dtype,
from_safetensors=model_path.suffix == ".safetensors",
)
return result
return output_path

View File

@@ -10,14 +10,13 @@ from diffusers.models.modeling_utils import ModelMixin
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase
from .. import ModelLoader, ModelLoaderRegistry
@@ -29,15 +28,14 @@ class GenericDiffusersLoader(ModelLoader):
def _load_model(
self,
config: AnyModelConfig,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
model_path = Path(config.path)
model_class = self.get_hf_load_class(model_path)
if submodel_type is not None:
raise Exception(f"There are no submodels in models of type {model_class}")
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
variant = repo_variant.value if repo_variant else None
variant = model_variant.value if model_variant else None
try:
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)
except OSError as e:

View File

@@ -7,26 +7,31 @@ from typing import Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager import (
AnyModel,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
from invokeai.backend.raw_model import RawModel
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.Checkpoint)
class IPAdapterInvokeAILoader(ModelLoader):
"""Class to load IP Adapter diffusers models."""
def _load_model(
self,
config: AnyModelConfig,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("There are no submodels in an IP-Adapter model.")
model_path = Path(config.path)
model: RawModel = build_ip_adapter(
ip_adapter_ckpt_path=model_path,
model = build_ip_adapter(
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
device=torch.device("cpu"),
dtype=self._torch_dtype,
)

View File

@@ -3,15 +3,16 @@
from logging import Logger
from pathlib import Path
from typing import Optional
from typing import Optional, Tuple
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.lora_model_raw import LoRAModelRaw
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
@@ -40,24 +41,26 @@ class LoRALoader(ModelLoader):
def _load_model(
self,
config: AnyModelConfig,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("There are no submodels in a LoRA model.")
model_path = Path(config.path)
assert self._model_base is not None
model = LoRAModelRaw.from_checkpoint(
file_path=model_path,
dtype=self._torch_dtype,
base_model=self._model_base,
)
return model
# override
def _get_model_path(self, config: AnyModelConfig) -> Path:
# cheating a little - we remember this variable for using in the subsequent call to _load_model()
self._model_base = config.base
def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
self._model_base = (
config.base
) # cheating a little - we remember this variable for using in the subsequent call to _load_model()
model_base_path = self._app_config.models_path
model_path = model_base_path / config.path
@@ -69,4 +72,5 @@ class LoRALoader(ModelLoader):
model_path = path
break
return model_path.resolve()
result = model_path.resolve(), config, submodel_type
return result

View File

@@ -7,9 +7,9 @@ from typing import Optional
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
@@ -25,19 +25,18 @@ class OnnyxDiffusersModel(GenericDiffusersLoader):
def _load_model(
self,
config: AnyModelConfig,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not submodel_type is not None:
raise Exception("A submodel type must be provided when loading onnx pipelines.")
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = getattr(config, "repo_variant", None)
variant = repo_variant.value if repo_variant else None
variant = model_variant.value if model_variant else None
model_path = model_path / submodel_type.value
result: AnyModel = load_class.from_pretrained(
model_path,
torch_dtype=self._torch_dtype,
variant=variant,
)
) # type: ignore
return result

View File

@@ -9,16 +9,12 @@ from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
DiffusersConfigBase,
MainCheckpointConfig,
ModelVariantType,
)
from invokeai.backend.model_manager.config import CheckpointConfigBase, MainCheckpointConfig, ModelVariantType
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from .. import ModelLoaderRegistry
@@ -45,15 +41,14 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
def _load_model(
self,
config: AnyModelConfig,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if not submodel_type is not None:
raise Exception("A submodel type must be provided when loading main pipelines.")
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
variant = repo_variant.value if repo_variant else None
variant = model_variant.value if model_variant else None
model_path = model_path / submodel_type.value
try:
result: AnyModel = load_class.from_pretrained(
@@ -83,7 +78,7 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
else:
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
assert isinstance(config, MainCheckpointConfig)
base = config.base
@@ -99,11 +94,11 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
self._logger.info(f"Converting {model_path} to diffusers format")
loaded_model = convert_ckpt_to_diffusers(
convert_ckpt_to_diffusers(
model_path,
output_path,
model_type=self.model_base_to_model_type[base],
original_config_file=self._app_config.legacy_conf_path / config.config_path,
original_config_file=self._app_config.root_path / config.config_path,
extract_ema=True,
from_safetensors=model_path.suffix == ".safetensors",
precision=self._torch_dtype,
@@ -113,4 +108,4 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
load_safety_checker=False,
num_in_channels=VARIANT_TO_IN_CHANNEL_MAP[config.variant],
)
return loaded_model
return output_path

View File

@@ -2,13 +2,14 @@
"""Class for TI model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
from typing import Optional, Tuple
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
@@ -26,19 +27,22 @@ class TextualInversionLoader(ModelLoader):
def _load_model(
self,
config: AnyModelConfig,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("There are no submodels in a TI model.")
model = TextualInversionModelRaw.from_checkpoint(
file_path=config.path,
file_path=model_path,
dtype=self._torch_dtype,
)
return model
# override
def _get_model_path(self, config: AnyModelConfig) -> Path:
def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
model_path = self._app_config.models_path / config.path
if config.format == ModelFormat.EmbeddingFolder:
@@ -49,4 +53,4 @@ class TextualInversionLoader(ModelLoader):
if not path.exists():
raise OSError(f"The embedding file at {path} was not found")
return path
return path, config, submodel_type

View File

@@ -2,7 +2,6 @@
"""Class for VAE model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
import torch
from omegaconf import DictConfig, OmegaConf
@@ -14,7 +13,7 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import AnyModel, CheckpointConfigBase
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from .. import ModelLoaderRegistry
@@ -39,13 +38,13 @@ class VAELoader(GenericDiffusersLoader):
else:
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
# TODO(MM2): check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"VAE conversion not supported for model type: {config.base}")
else:
assert isinstance(config, CheckpointConfigBase)
config_file = self._app_config.legacy_conf_path / config.config_path
config_file = self._app_config.root_path / config.config_path
if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu")
@@ -64,6 +63,6 @@ class VAELoader(GenericDiffusersLoader):
vae_config=ckpt_config,
image_size=512,
precision=self._torch_dtype,
dump_path=output_path,
)
return vae_model
vae_model.save_pretrained(output_path, safe_serialization=True)
return output_path

View File

@@ -230,10 +230,9 @@ class ModelProbe(object):
return ModelType.LoRA
elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
return ModelType.ControlNet
elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
return ModelType.IPAdapter
elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion
else:
# diffusers-ti
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
@@ -324,7 +323,7 @@ class ModelProbe(object):
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
cls._scan_model(model_path.name, model_path)
model = torch.load(model_path, map_location="cpu")
model = torch.load(model_path)
assert isinstance(model, dict)
return model
else:
@@ -528,25 +527,8 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
class IPAdapterCheckpointProbe(CheckpointProbeBase):
"""Class for probing IP Adapters"""
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
for key in checkpoint.keys():
if not key.startswith(("image_proj.", "ip_adapter.")):
continue
cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
return BaseModelType.StableDiffusion1
elif cross_attention_dim == 1024:
return BaseModelType.StableDiffusion2
elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
)
raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")
raise NotImplementedError()
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
@@ -786,7 +768,7 @@ class T2IAdapterFolderProbe(FolderProbeBase):
)
# Register probe classes
############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)

View File

@@ -17,7 +17,7 @@ from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from .lora import LoRAModelRaw
from .lora_model_raw import LoRAModelRaw
from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
"""

View File

@@ -6,17 +6,16 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np
import onnx
import torch
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from ..raw_model import RawModel
ONNX_WEIGHTS_NAME = "model.onnx"
# NOTE FROM LS: This was copied from Stalker's original implementation.
# I have not yet gone through and fixed all the type hints
class IAIOnnxRuntimeModel(RawModel):
class IAIOnnxRuntimeModel(torch.nn.Module):
class _tensor_access:
def __init__(self, model): # type: ignore
self.model = model

View File

@@ -1,15 +0,0 @@
"""Base class for 'Raw' models.
The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
and is used for type checking of calls to the model patcher. Its main purpose
is to avoid a circular import issues when lora.py tries to import BaseModelType
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
from lora.py.
The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
that adds additional methods and attributes.
"""
class RawModel:
"""Base class for 'Raw' model wrappers."""

View File

@@ -28,10 +28,6 @@ def _conv_forward_asymmetric(self, input, weight, bias):
@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
if not seamless_axes:
yield
return
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
try:

View File

@@ -9,10 +9,8 @@ from safetensors.torch import load_file
from transformers import CLIPTokenizer
from typing_extensions import Self
from .raw_model import RawModel
class TextualInversionModelRaw(RawModel):
class TextualInversionModelRaw(torch.nn.Module):
embedding: torch.Tensor # [n, 768]|[n, 1280]
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models

View File

@@ -31,9 +31,6 @@ class ConfigMapper:
YAML_FILENAME = "invokeai.yaml"
DATABASE_FILENAME = "invokeai.db"
DEFAULT_OUTDIR = "outputs"
DEFAULT_DB_DIR = "databases"
database_path = None
database_backup_dir = None
outputs_path = None
@@ -53,18 +50,12 @@ class ConfigMapper:
def __load_from_root_config(self, invoke_root):
"""Validate a yaml path exists, confirm the user wants to use it and load config."""
yaml_path = os.path.join(invoke_root, self.YAML_FILENAME)
if not os.path.exists(yaml_path):
print(f"Unable to find invokeai.yaml at {yaml_path}!")
return False
if os.path.exists(yaml_path):
db_dir, outdir = self.__load_paths_from_yaml_file(yaml_path)
if db_dir is None:
db_dir = self.DEFAULT_DB_DIR
print(f"The invokeai.yaml file was found but is missing the db_dir setting! Defaulting to {db_dir}")
if outdir is None:
outdir = self.DEFAULT_OUTDIR
print(f"The invokeai.yaml file was found but is missing the outdir setting! Defaulting to {outdir}")
if db_dir is None or outdir is None:
print("The invokeai.yaml file was found but is missing the db_dir and/or outdir setting!")
return False
if os.path.isabs(db_dir):
self.database_path = os.path.join(db_dir, self.DATABASE_FILENAME)

View File

@@ -94,7 +94,6 @@
"reactflow": "^11.10.4",
"redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^5.1.0",
"rfdc": "^1.3.1",
"roarr": "^7.21.1",
"serialize-error": "^11.0.3",
"socket.io-client": "^4.7.5",

View File

@@ -137,9 +137,6 @@ dependencies:
redux-remember:
specifier: ^5.1.0
version: 5.1.0(redux@5.0.1)
rfdc:
specifier: ^1.3.1
version: 1.3.1
roarr:
specifier: ^7.21.1
version: 7.21.1
@@ -12131,10 +12128,6 @@ packages:
resolution: {integrity: sha512-/x8uIPdTafBqakK0TmPNJzgkLP+3H+yxpUJhCQHsLBg1rYEVNR2D8BRYNWQhVBjyOd7oo1dZRVzIkwMY2oqfYQ==}
dev: true
/rfdc@1.3.1:
resolution: {integrity: sha512-r5a3l5HzYlIC68TpmYKlxWjmOP6wiPJ1vWv2HeLhNsRZMrCkxeqxiHlQ21oXmQ4F3SiryXBHhAD7JZqvOJjFmg==}
dev: false
/rimraf@2.6.3:
resolution: {integrity: sha512-mwqeW5XsA2qAejG46gYdENaxXjx9onRNCfn7L0duuP4hCuTIi/QO7PDK07KJfp1d+izWPrzEJDcSqBa0OZQriA==}
hasBin: true

View File

@@ -4,7 +4,7 @@
"reportBugLabel": "Fehler melden",
"settingsLabel": "Einstellungen",
"img2img": "Bild zu Bild",
"nodes": "Arbeitsabläufe",
"nodes": "Knoten Editor",
"upload": "Hochladen",
"load": "Laden",
"statusDisconnected": "Getrennt",
@@ -74,8 +74,7 @@
"updated": "Aktualisiert",
"copy": "Kopieren",
"aboutHeading": "Nutzen Sie Ihre kreative Energie",
"toResolve": "Lösen",
"add": "Hinzufügen"
"toResolve": "Lösen"
},
"gallery": {
"galleryImageSize": "Bildgröße",
@@ -105,16 +104,11 @@
"dropToUpload": "$t(gallery.drop) zum hochladen",
"dropOrUpload": "$t(gallery.drop) oder hochladen",
"drop": "Ablegen",
"problemDeletingImages": "Problem beim Löschen der Bilder",
"bulkDownloadRequested": "Download vorbereiten",
"bulkDownloadRequestedDesc": "Dein Download wird vorbereitet. Dies kann ein paar Momente dauern.",
"bulkDownloadRequestFailed": "Problem beim Download vorbereiten",
"bulkDownloadFailed": "Download fehlgeschlagen",
"alwaysShowImageSizeBadge": "Zeige immer Bilder Größe Abzeichen"
"problemDeletingImages": "Problem beim Löschen der Bilder"
},
"hotkeys": {
"keyboardShortcuts": "Tastenkürzel",
"appHotkeys": "App",
"appHotkeys": "App-Tastenkombinationen",
"generalHotkeys": "Allgemein",
"galleryHotkeys": "Galerie",
"unifiedCanvasHotkeys": "Leinwand",
@@ -763,9 +757,7 @@
"scheduler": "Planer",
"noRecallParameters": "Es wurden keine Parameter zum Abrufen gefunden",
"recallParameters": "Parameter wiederherstellen",
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
"allPrompts": "Alle Prompts",
"imageDimensions": "Bilder Auslösungen"
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
},
"popovers": {
"noiseUseCPU": {
@@ -1076,10 +1068,5 @@
},
"dynamicPrompts": {
"showDynamicPrompts": "Dynamische Prompts anzeigen"
},
"prompt": {
"noMatchingTriggers": "Keine passenden Auslöser",
"addPromptTrigger": "Auslöse Text hinzufügen",
"compatibleEmbeddings": "Kompatible Einbettungen"
}
}

View File

@@ -217,7 +217,6 @@
"saveControlImage": "Save Control Image",
"scribble": "scribble",
"selectModel": "Select a model",
"selectCLIPVisionModel": "Select a CLIP Vision model",
"setControlImageDimensions": "Set Control Image Dimensions To W/H",
"showAdvanced": "Show Advanced",
"small": "Small",
@@ -656,7 +655,6 @@
"install": "Install",
"installAll": "Install All",
"installRepo": "Install Repo",
"ipAdapters": "IP Adapters",
"load": "Load",
"localOnly": "local only",
"manual": "Manual",

View File

@@ -73,8 +73,7 @@
"ai": "ia",
"file": "File",
"toResolve": "Da risolvere",
"add": "Aggiungi",
"loglevel": "Livello di log"
"add": "Aggiungi"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -935,9 +934,7 @@
"base": "Base",
"lineart": "Linea",
"controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
"mediapipeFace": "Mediapipe Volto",
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))"
"mediapipeFace": "Mediapipe Volto"
},
"queue": {
"queueFront": "Aggiungi all'inizio della coda",
@@ -1493,8 +1490,7 @@
"title": "Generazione"
},
"advanced": {
"title": "Avanzate",
"options": "Opzioni $t(accordions.advanced.title)"
"title": "Avanzate"
},
"image": {
"title": "Immagine"

View File

@@ -75,8 +75,7 @@
"copy": "Копировать",
"localSystem": "Локальная система",
"aboutDesc": "Используя Invoke для работы? Проверьте это:",
"add": "Добавить",
"loglevel": "Уровень логов"
"add": "Добавить"
},
"gallery": {
"galleryImageSize": "Размер изображений",
@@ -1506,8 +1505,7 @@
"title": "Генерация"
},
"advanced": {
"title": "Расширенные",
"options": "Опции $t(accordions.advanced.title)"
"title": "Расширенные"
},
"image": {
"title": "Изображение"

View File

@@ -1,7 +1,7 @@
import type { UnknownAction } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { cloneDeep } from 'lodash-es';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions';
@@ -33,7 +33,7 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
}
if (socketGeneratorProgress.match(action)) {
const sanitized = deepClone(action);
const sanitized = cloneDeep(action);
if (sanitized.payload.data.progress_image) {
sanitized.payload.data.progress_image.dataURL = '<Progress image omitted>';
}

View File

@@ -43,7 +43,6 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
})
);
dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
},
});

View File

@@ -1,5 +1,4 @@
import { deepClone } from 'common/util/deepClone';
import { merge } from 'lodash-es';
import { cloneDeep, merge } from 'lodash-es';
import { ClickScrollPlugin, OverlayScrollbars } from 'overlayscrollbars';
import type { UseOverlayScrollbarsParams } from 'overlayscrollbars-react';
@@ -23,7 +22,7 @@ export const getOverlayScrollbarsParams = (
overflowX: 'hidden' | 'scroll' = 'hidden',
overflowY: 'hidden' | 'scroll' = 'scroll'
) => {
const params = deepClone(overlayScrollbarsParams);
const params = cloneDeep(overlayScrollbarsParams);
merge(params, { options: { overflow: { y: overflowY, x: overflowX } } });
return params;
};

View File

@@ -1,15 +0,0 @@
import rfdc from 'rfdc';
const _rfdc = rfdc();
/**
* Deep-clones an object using Really Fast Deep Clone.
* This is the fastest deep clone library on Chrome, but not the fastest on FF. Still, it's much faster than lodash
* and structuredClone, so it's the best all-around choice.
*
* Simple Benchmark: https://www.measurethat.net/Benchmarks/Show/30358/0/lodash-clonedeep-vs-jsonparsejsonstringify-vs-recursive
* Repo: https://github.com/davidmarkclements/rfdc
*
* @param obj The object to deep-clone
* @returns The cloned object
*/
export const deepClone = <T>(obj: T): T => _rfdc(obj);

View File

@@ -1,7 +1,6 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
import calculateCoordinates from 'features/canvas/util/calculateCoordinates';
import calculateScale from 'features/canvas/util/calculateScale';
@@ -14,7 +13,7 @@ import { modelChanged } from 'features/parameters/store/generationSlice';
import type { PayloadActionWithOptimalDimension } from 'features/parameters/store/types';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect, Vector2d } from 'konva/lib/types';
import { clamp } from 'lodash-es';
import { clamp, cloneDeep } from 'lodash-es';
import type { RgbaColor } from 'react-colorful';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO } from 'services/api/types';
@@ -37,7 +36,7 @@ import { CANVAS_GRID_SIZE_FINE } from './constants';
/**
* The maximum history length to keep in the past/future layer states.
*/
const MAX_HISTORY = 100;
const MAX_HISTORY = 128;
const initialLayerState: CanvasLayerState = {
objects: [],
@@ -122,7 +121,7 @@ export const canvasSlice = createSlice({
state.brushSize = action.payload;
},
clearMask: (state) => {
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));
state.layerState.objects = state.layerState.objects.filter((obj) => !isCanvasMaskLine(obj));
state.futureLayerStates = [];
state.shouldPreserveMaskedArea = false;
@@ -164,10 +163,10 @@ export const canvasSlice = createSlice({
state.boundingBoxDimensions = newBoundingBoxDimensions;
state.boundingBoxCoordinates = newBoundingBoxCoordinates;
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));
state.layerState = {
...deepClone(initialLayerState),
...cloneDeep(initialLayerState),
objects: [
{
kind: 'image',
@@ -262,7 +261,11 @@ export const canvasSlice = createSlice({
return;
}
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));
if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
state.layerState.stagingArea.images.push({
kind: 'image',
@@ -276,9 +279,13 @@ export const canvasSlice = createSlice({
state.futureLayerStates = [];
},
discardStagedImages: (state) => {
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));
state.layerState.stagingArea = deepClone(initialLayerState.stagingArea);
if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
state.layerState.stagingArea = cloneDeep(cloneDeep(initialLayerState)).stagingArea;
state.futureLayerStates = [];
state.shouldShowStagingOutline = true;
@@ -287,7 +294,11 @@ export const canvasSlice = createSlice({
},
discardStagedImage: (state) => {
const { images, selectedImageIndex } = state.layerState.stagingArea;
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));
if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
if (!images.length) {
return;
@@ -309,7 +320,11 @@ export const canvasSlice = createSlice({
addFillRect: (state) => {
const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } = state;
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));
if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
state.layerState.objects.push({
kind: 'fillRect',
@@ -324,7 +339,11 @@ export const canvasSlice = createSlice({
addEraseRect: (state) => {
const { boundingBoxCoordinates, boundingBoxDimensions } = state;
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));
if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
state.layerState.objects.push({
kind: 'eraseRect',
@@ -348,7 +367,11 @@ export const canvasSlice = createSlice({
// set & then spread this to only conditionally add the "color" key
const newColor = layer === 'base' && tool === 'brush' ? { color: brushColor } : {};
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));
if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
const newLine: CanvasMaskLine | CanvasBaseLine = {
kind: 'line',
@@ -386,7 +409,11 @@ export const canvasSlice = createSlice({
return;
}
pushToFutureLayerStates(state);
state.futureLayerStates.unshift(cloneDeep(state.layerState));
if (state.futureLayerStates.length > MAX_HISTORY) {
state.futureLayerStates.pop();
}
state.layerState = targetState;
},
@@ -397,7 +424,11 @@ export const canvasSlice = createSlice({
return;
}
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));
if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
state.layerState = targetState;
},
@@ -414,8 +445,8 @@ export const canvasSlice = createSlice({
state.shouldShowIntermediates = action.payload;
},
resetCanvas: (state) => {
pushToPrevLayerStates(state);
state.layerState = deepClone(initialLayerState);
state.pastLayerStates.push(cloneDeep(state.layerState));
state.layerState = cloneDeep(initialLayerState);
state.futureLayerStates = [];
state.batchIds = [];
state.boundingBoxCoordinates = {
@@ -509,7 +540,11 @@ export const canvasSlice = createSlice({
const { images, selectedImageIndex } = state.layerState.stagingArea;
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));
if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates.shift();
}
const imageToCommit = images[selectedImageIndex];
@@ -518,7 +553,7 @@ export const canvasSlice = createSlice({
...imageToCommit,
});
}
state.layerState.stagingArea = deepClone(initialLayerState.stagingArea);
state.layerState.stagingArea = cloneDeep(initialLayerState).stagingArea;
state.futureLayerStates = [];
state.shouldShowStagingOutline = true;
@@ -588,7 +623,7 @@ export const canvasSlice = createSlice({
};
},
setMergedCanvas: (state, action: PayloadAction<CanvasImage>) => {
pushToPrevLayerStates(state);
state.pastLayerStates.push(cloneDeep(state.layerState));
state.futureLayerStates = [];
@@ -708,17 +743,3 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
migrate: migrateCanvasState,
persistDenylist: [],
};
const pushToPrevLayerStates = (state: CanvasState) => {
state.pastLayerStates.push(deepClone(state.layerState));
if (state.pastLayerStates.length > MAX_HISTORY) {
state.pastLayerStates = state.pastLayerStates.slice(-MAX_HISTORY);
}
};
const pushToFutureLayerStates = (state: CanvasState) => {
state.futureLayerStates.unshift(deepClone(state.layerState));
if (state.futureLayerStates.length > MAX_HISTORY) {
state.futureLayerStates = state.futureLayerStates.slice(0, MAX_HISTORY);
}
};

View File

@@ -1,18 +1,12 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useControlAdapterCLIPVisionModel } from 'features/controlAdapters/hooks/useControlAdapterCLIPVisionModel';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import {
controlAdapterCLIPVisionModelChanged,
controlAdapterModelChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import type { CLIPVisionModel } from 'features/controlAdapters/store/types';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -35,7 +29,6 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const { modelConfig } = useControlAdapterModel(id);
const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const currentCLIPVisionModel = useControlAdapterCLIPVisionModel(id);
const mainModel = useAppSelector(selectMainModel);
const { t } = useTranslation();
@@ -56,16 +49,6 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
[dispatch, id]
);
const onCLIPVisionModelChange = useCallback<ComboboxOnChange>(
(v) => {
if (!v?.value) {
return;
}
dispatch(controlAdapterCLIPVisionModelChanged({ id, clipVisionModel: v.value as CLIPVisionModel }));
},
[dispatch, id]
);
const selectedModel = useMemo(
() => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
[controlAdapterType, modelConfig]
@@ -88,51 +71,18 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
isLoading,
});
const clipVisionOptions = useMemo<ComboboxOption[]>(
() => [
{ label: 'ViT-H', value: 'ViT-H' },
{ label: 'ViT-G', value: 'ViT-G' },
],
[]
);
const clipVisionModel = useMemo(
() => clipVisionOptions.find((o) => o.value === currentCLIPVisionModel),
[clipVisionOptions, currentCLIPVisionModel]
);
return (
<Flex sx={{ gap: 2 }}>
<Tooltip label={value?.description}>
<FormControl
isDisabled={!isEnabled}
isInvalid={!value || mainModel?.base !== modelConfig?.base}
sx={{ width: '100%' }}
>
<Combobox
options={options}
placeholder={t('controlnet.selectModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
{modelConfig?.type === 'ip_adapter' && modelConfig.format === 'checkpoint' && (
<FormControl
isDisabled={!isEnabled}
isInvalid={!value || mainModel?.base !== modelConfig?.base}
sx={{ width: 'max-content', minWidth: 28 }}
>
<Combobox
options={clipVisionOptions}
placeholder={t('controlnet.selectCLIPVisionModel')}
value={clipVisionModel}
onChange={onCLIPVisionModelChange}
/>
</FormControl>
)}
</Flex>
<Tooltip label={value?.description}>
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== modelConfig?.base}>
<Combobox
options={options}
placeholder={t('controlnet.selectModel')}
value={value}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
);
};

View File

@@ -1,24 +0,0 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectControlAdapterById,
selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useMemo } from 'react';
export const useControlAdapterCLIPVisionModel = (id: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
const cn = selectControlAdapterById(controlAdapters, id);
if (cn && cn?.type === 'ip_adapter') {
return cn.clipVisionModel;
}
}),
[id]
);
const clipVisionModel = useAppSelector(selector);
return clipVisionModel;
};

View File

@@ -2,11 +2,10 @@ import type { PayloadAction, Update } from '@reduxjs/toolkit';
import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { merge, uniq } from 'lodash-es';
import { cloneDeep, merge, uniq } from 'lodash-es';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { socketInvocationError } from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
@@ -14,7 +13,6 @@ import { v4 as uuidv4 } from 'uuid';
import { controlAdapterImageProcessed } from './actions';
import { CONTROLNET_PROCESSORS } from './constants';
import type {
CLIPVisionModel,
ControlAdapterConfig,
ControlAdapterProcessorType,
ControlAdaptersState,
@@ -116,7 +114,7 @@ export const controlAdaptersSlice = createSlice({
if (!controlAdapter) {
return;
}
const newControlAdapter = merge(deepClone(controlAdapter), {
const newControlAdapter = merge(cloneDeep(controlAdapter), {
id: newId,
isEnabled: true,
});
@@ -245,13 +243,6 @@ export const controlAdaptersSlice = createSlice({
}
caAdapter.updateOne(state, { id, changes: { controlMode } });
},
controlAdapterCLIPVisionModelChanged: (
state,
action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
) => {
const { id, clipVisionModel } = action.payload;
caAdapter.updateOne(state, { id, changes: { clipVisionModel } });
},
controlAdapterResizeModeChanged: (
state,
action: PayloadAction<{
@@ -279,7 +270,7 @@ export const controlAdaptersSlice = createSlice({
return;
}
const processorNode = merge(deepClone(cn.processorNode), params);
const processorNode = merge(cloneDeep(cn.processorNode), params);
caAdapter.updateOne(state, {
id,
@@ -302,7 +293,7 @@ export const controlAdaptersSlice = createSlice({
return;
}
const processorNode = deepClone(
const processorNode = cloneDeep(
CONTROLNET_PROCESSORS[processorType].buildDefaults(cn.model?.base)
) as RequiredControlAdapterProcessorNode;
@@ -342,7 +333,7 @@ export const controlAdaptersSlice = createSlice({
caAdapter.updateOne(state, update);
},
controlAdaptersReset: () => {
return deepClone(initialControlAdaptersState);
return cloneDeep(initialControlAdaptersState);
},
pendingControlImagesCleared: (state) => {
state.pendingControlImages = [];
@@ -389,7 +380,6 @@ export const {
controlAdapterProcessedImageChanged,
controlAdapterIsEnabledChanged,
controlAdapterModelChanged,
controlAdapterCLIPVisionModelChanged,
controlAdapterWeightChanged,
controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged,
@@ -416,7 +406,7 @@ const migrateControlAdaptersState = (state: any): any => {
state._version = 1;
}
if (state._version === 1) {
state = deepClone(initialControlAdaptersState);
state = cloneDeep(initialControlAdaptersState);
}
return state;
};

View File

@@ -243,15 +243,12 @@ export type T2IAdapterConfig = {
shouldAutoConfig: boolean;
};
export type CLIPVisionModel = 'ViT-H' | 'ViT-G';
export type IPAdapterConfig = {
type: 'ip_adapter';
id: string;
isEnabled: boolean;
controlImage: string | null;
model: ParameterIPAdapterModel | null;
clipVisionModel: CLIPVisionModel;
weight: number;
beginStepPct: number;
endStepPct: number;

View File

@@ -1,4 +1,3 @@
import { deepClone } from 'common/util/deepClone';
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type {
ControlAdapterConfig,
@@ -8,7 +7,7 @@ import type {
RequiredCannyImageProcessorInvocation,
T2IAdapterConfig,
} from 'features/controlAdapters/store/types';
import { merge } from 'lodash-es';
import { cloneDeep, merge } from 'lodash-es';
export const initialControlNet: Omit<ControlNetConfig, 'id'> = {
type: 'controlnet',
@@ -46,7 +45,6 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
isEnabled: true,
controlImage: null,
model: null,
clipVisionModel: 'ViT-H',
weight: 1,
beginStepPct: 0,
endStepPct: 1,
@@ -59,11 +57,11 @@ export const buildControlAdapter = (
): ControlAdapterConfig => {
switch (type) {
case 'controlnet':
return merge(deepClone(initialControlNet), { id, ...overrides });
return merge(cloneDeep(initialControlNet), { id, ...overrides });
case 't2i_adapter':
return merge(deepClone(initialT2IAdapter), { id, ...overrides });
return merge(cloneDeep(initialT2IAdapter), { id, ...overrides });
case 'ip_adapter':
return merge(deepClone(initialIPAdapter), { id, ...overrides });
return merge(cloneDeep(initialIPAdapter), { id, ...overrides });
default:
throw new Error(`Unknown control adapter type: ${type}`);
}

View File

@@ -1,9 +1,9 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import { cloneDeep } from 'lodash-es';
import type { LoRAModelConfig } from 'services/api/types';
export type LoRA = {
@@ -58,7 +58,7 @@ export const loraSlice = createSlice({
}
lora.isEnabled = isEnabled;
},
lorasReset: () => deepClone(initialLoraState),
lorasReset: () => cloneDeep(initialLoraState),
},
});
@@ -74,7 +74,7 @@ const migrateLoRAState = (state: any): any => {
}
if (state._version === 1) {
// Model type has changed, so we need to reset the state - too risky to migrate
state = deepClone(initialLoraState);
state = cloneDeep(initialLoraState);
}
return state;
};

View File

@@ -372,7 +372,6 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
type: 'ip_adapter',
isEnabled: true,
model: zModelIdentifierField.parse(ipAdapterModel),
clipVisionModel: 'ViT-H',
controlImage: image?.image_name ?? null,
weight: weight ?? initialIPAdapter.weight,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,

View File

@@ -87,10 +87,6 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
}, [installJob.source]);
const progressValue = useMemo(() => {
if (installJob.status === 'completed' || installJob.status === 'error' || installJob.status === 'cancelled') {
return 100;
}
if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
return null;
}
@@ -100,7 +96,7 @@ export const ModelInstallQueueItem = (props: ModelListItemProps) => {
}
return (installJob.bytes / installJob.total_bytes) * 100;
}, [installJob.bytes, installJob.status, installJob.total_bytes]);
}, [installJob.bytes, installJob.total_bytes]);
return (
<Flex gap={3} w="full" alignItems="center">

View File

@@ -1,19 +1,48 @@
import { Badge, Box, Flex, IconButton, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import type { ScanFolderResponse } from 'services/api/endpoints/models';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type Props = {
result: ScanFolderResponse[number];
installModel: (source: string) => void;
};
export const ScanModelResultItem = ({ result, installModel }: Props) => {
export const ScanModelResultItem = ({ result }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const handleInstall = useCallback(() => {
installModel(result.path);
}, [installModel, result]);
const [installModel] = useInstallModelMutation();
const handleQuickAdd = useCallback(() => {
installModel({ source: result.path })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
}, [installModel, result, dispatch, t]);
return (
<Flex alignItems="center" justifyContent="space-between" w="100%" gap={3}>
@@ -25,7 +54,7 @@ export const ScanModelResultItem = ({ result, installModel }: Props) => {
{result.is_installed ? (
<Badge>{t('common.installed')}</Badge>
) : (
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleInstall} size="sm" />
<IconButton aria-label={t('modelManager.install')} icon={<PiPlusBold />} onClick={handleQuickAdd} size="sm" />
)}
</Box>
</Flex>

View File

@@ -1,10 +1,7 @@
import {
Button,
Checkbox,
Divider,
Flex,
FormControl,
FormLabel,
Heading,
IconButton,
Input,
@@ -15,7 +12,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { ChangeEvent, ChangeEventHandler } from 'react';
import type { ChangeEventHandler } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
@@ -31,7 +28,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
const { t } = useTranslation();
const [searchTerm, setSearchTerm] = useState('');
const dispatch = useAppDispatch();
const [inplace, setInplace] = useState(true);
const [installModel] = useInstallModelMutation();
const filteredResults = useMemo(() => {
@@ -45,10 +42,6 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
setSearchTerm(e.target.value.trim());
}, []);
const onChangeInplace = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setInplace(e.target.checked);
}, []);
const clearSearch = useCallback(() => {
setSearchTerm('');
}, []);
@@ -58,7 +51,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
if (result.is_installed) {
continue;
}
installModel({ source: result.path, inplace })
installModel({ source: result.path })
.unwrap()
.then((_) => {
dispatch(
@@ -83,37 +76,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
}
});
}
}, [filteredResults, installModel, inplace, dispatch, t]);
const handleInstallOne = useCallback(
(source: string) => {
installModel({ source, inplace })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddedSimple'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
},
[installModel, inplace, dispatch, t]
);
}, [installModel, filteredResults, dispatch, t]);
return (
<>
@@ -122,10 +85,6 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
<Flex justifyContent="space-between" alignItems="center">
<Heading size="sm">{t('modelManager.scanResults')}</Heading>
<Flex alignItems="center" gap={3}>
<FormControl w="min-content">
<FormLabel m={0}>{t('modelManager.inplaceInstall')}</FormLabel>
<Checkbox isChecked={inplace} onChange={onChangeInplace} size="md" />
</FormControl>
<Button size="sm" onClick={handleAddAll} isDisabled={filteredResults.length === 0}>
{t('modelManager.installAll')}
</Button>
@@ -157,7 +116,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
<ScrollableContent>
<Flex flexDir="column" gap={3}>
{filteredResults.map((result) => (
<ScanModelResultItem key={result.path} result={result} installModel={handleInstallOne} />
<ScanModelResultItem key={result.path} result={result} />
))}
</Flex>
</ScrollableContent>

View File

@@ -90,13 +90,11 @@ const ModelListItem = (props: ModelListItemProps) => {
cursor="pointer"
onClick={handleSelectModel}
>
<Flex gap={2} w="full" h="full" minW={0}>
<Flex gap={2} w="full" h="full">
<ModelImage image_url={model.cover_image} />
<Flex gap={1} alignItems="flex-start" flexDir="column" w="full" minW={0}>
<Flex gap={1} alignItems="flex-start" flexDir="column" w="full">
<Flex gap={2} w="full" alignItems="flex-start">
<Text fontWeight="semibold" noOfLines={1} wordBreak="break-all">
{model.name}
</Text>
<Text fontWeight="semibold">{model.name}</Text>
<Spacer />
</Flex>
<Text variant="subtext" noOfLines={1}>

View File

@@ -87,9 +87,9 @@ export const Model = () => {
<Flex flexDir="column" gap={4}>
<Flex alignItems="flex-start" gap={4}>
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
<Flex flexDir="column" gap={1} flexGrow={1} minW={0}>
<Flex flexDir="column" gap={1} flexGrow={1}>
<Flex gap={2}>
<Heading as="h2" fontSize="lg" noOfLines={1} wordBreak="break-all">
<Heading as="h2" fontSize="lg">
{data.name}
</Heading>
<Spacer />
@@ -114,7 +114,7 @@ export const Model = () => {
)}
</Flex>
{data.source && (
<Text variant="subtext" noOfLines={1} wordBreak="break-all">
<Text variant="subtext">
{t('modelManager.source')}: {data?.source}
</Text>
)}

View File

@@ -9,9 +9,7 @@ export const ModelAttrView = ({ label, value }: Props) => {
return (
<FormControl flexDir="column" alignItems="flex-start" gap={0}>
<FormLabel>{label}</FormLabel>
<Text fontSize="md" noOfLines={1} wordBreak="break-all">
{value || '-'}
</Text>
<Text fontSize="md">{value || '-'}</Text>
</FormControl>
);
};

View File

@@ -53,7 +53,7 @@ export const ModelView = () => {
</>
)}
{data.type === 'ip_adapter' && data.format === 'invokeai' && (
{data.type === 'ip_adapter' && (
<Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
</Flex>

View File

@@ -1,7 +1,6 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
@@ -45,7 +44,7 @@ import {
} from 'features/nodes/types/field';
import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
import { forEach } from 'lodash-es';
import { cloneDeep, forEach } from 'lodash-es';
import type {
Connection,
Edge,
@@ -572,23 +571,8 @@ export const nodesSlice = createSlice({
);
},
selectionCopied: (state) => {
const nodesToCopy: AnyNode[] = [];
const edgesToCopy: Edge[] = [];
for (const node of state.nodes) {
if (node.selected) {
nodesToCopy.push(deepClone(node));
}
}
for (const edge of state.edges) {
if (edge.selected) {
edgesToCopy.push(deepClone(edge));
}
}
state.nodesToCopy = nodesToCopy;
state.edgesToCopy = edgesToCopy;
state.nodesToCopy = state.nodes.filter((n) => n.selected).map(cloneDeep);
state.edgesToCopy = state.edges.filter((e) => e.selected).map(cloneDeep);
if (state.nodesToCopy.length > 0) {
const averagePosition = { x: 0, y: 0 };
@@ -610,21 +594,11 @@ export const nodesSlice = createSlice({
},
selectionPasted: (state, action: PayloadAction<{ cursorPosition?: XYPosition }>) => {
const { cursorPosition } = action.payload;
const newNodes: AnyNode[] = [];
for (const node of state.nodesToCopy) {
newNodes.push(deepClone(node));
}
const newNodes = state.nodesToCopy.map(cloneDeep);
const oldNodeIds = newNodes.map((n) => n.data.id);
const newEdges: Edge[] = [];
for (const edge of state.edgesToCopy) {
if (oldNodeIds.includes(edge.source) && oldNodeIds.includes(edge.target)) {
newEdges.push(deepClone(edge));
}
}
const newEdges = state.edgesToCopy
.filter((e) => oldNodeIds.includes(e.source) && oldNodeIds.includes(e.target))
.map(cloneDeep);
newEdges.forEach((e) => (e.selected = true));

View File

@@ -1,7 +1,6 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged, nodesDeleted } from 'features/nodes/store/nodesSlice';
import type {
@@ -12,7 +11,7 @@ import type {
import type { FieldIdentifier } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type { WorkflowCategory, WorkflowV3 } from 'features/nodes/types/workflow';
import { isEqual, omit, uniqBy } from 'lodash-es';
import { cloneDeep, isEqual, omit, uniqBy } from 'lodash-es';
const blankWorkflow: Omit<WorkflowV3, 'nodes' | 'edges'> = {
name: '',
@@ -132,8 +131,8 @@ export const workflowSlice = createSlice({
});
return {
...deepClone(initialWorkflowState),
...deepClone(workflowExtra),
...cloneDeep(initialWorkflowState),
...cloneDeep(workflowExtra),
originalExposedFieldValues,
mode: state.mode,
};
@@ -145,7 +144,7 @@ export const workflowSlice = createSlice({
});
});
builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState));
builder.addCase(nodeEditorReset, () => cloneDeep(initialWorkflowState));
builder.addCase(nodesChanged, (state, action) => {
// Not all changes to nodes should result in the workflow being marked touched

View File

@@ -48,7 +48,7 @@ export const addIPAdapterToLinearGraph = async (
if (!ipAdapter.model) {
return;
}
const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter;
const { id, weight, model, beginStepPct, endStepPct, controlImage } = ipAdapter;
assert(controlImage, 'IP Adapter image is required');
@@ -58,7 +58,6 @@ export const addIPAdapterToLinearGraph = async (
is_intermediate: true,
weight: weight,
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image: {
@@ -84,7 +83,7 @@ export const addIPAdapterToLinearGraph = async (
};
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter;
const { controlImage, beginStepPct, endStepPct, model, weight } = ipAdapter;
assert(model, 'IP Adapter model is required');
@@ -100,7 +99,6 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat
return {
ip_adapter_model: model,
clip_vision_model: clipVisionModel,
weight,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,

View File

@@ -1,9 +1,8 @@
import { deepClone } from 'common/util/deepClone';
import { satisfies } from 'compare-versions';
import { NodeUpdateError } from 'features/nodes/types/error';
import type { InvocationNode, InvocationTemplate } from 'features/nodes/types/invocation';
import { zParsedSemver } from 'features/nodes/types/semver';
import { defaultsDeep, keys, pick } from 'lodash-es';
import { cloneDeep, defaultsDeep, keys, pick } from 'lodash-es';
import { buildInvocationNode } from './buildInvocationNode';
@@ -51,7 +50,7 @@ export const updateNode = (node: InvocationNode, template: InvocationTemplate):
// The updateability of a node, via semver comparison, relies on the this kind of recursive merge
// being valid. We rely on the template's major version to be majorly incremented if this kind of
// merge would result in an invalid node.
const clone = deepClone(node);
const clone = cloneDeep(node);
clone.data.version = template.version;
defaultsDeep(clone, defaults); // mutates!

View File

@@ -1,12 +1,11 @@
import { logger } from 'app/logging/logger';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import type { NodesState, WorkflowsState } from 'features/nodes/store/types';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { zWorkflowV3 } from 'features/nodes/types/workflow';
import i18n from 'i18n';
import { pick } from 'lodash-es';
import { cloneDeep, pick } from 'lodash-es';
import { fromZodError } from 'zod-validation-error';
export type BuildWorkflowArg = {
@@ -31,7 +30,7 @@ const workflowKeys = [
type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV3;
export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 => {
const clonedWorkflow = pick(deepClone(workflow), workflowKeys);
const clonedWorkflow = pick(cloneDeep(workflow), workflowKeys);
const newWorkflow: WorkflowV3 = {
...clonedWorkflow,
@@ -44,14 +43,14 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo
newWorkflow.nodes.push({
id: node.id,
type: node.type,
data: deepClone(node.data),
data: cloneDeep(node.data),
position: { ...node.position },
});
} else if (isNotesNode(node) && node.type) {
newWorkflow.nodes.push({
id: node.id,
type: node.type,
data: deepClone(node.data),
data: cloneDeep(node.data),
position: { ...node.position },
});
}

View File

@@ -1,5 +1,4 @@
import { $store } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import type { FieldType } from 'features/nodes/types/field';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
@@ -12,7 +11,7 @@ import { zWorkflowV2 } from 'features/nodes/types/v2/workflow';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { zWorkflowV3 } from 'features/nodes/types/workflow';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
import { cloneDeep, forEach } from 'lodash-es';
import { z } from 'zod';
/**
@@ -90,7 +89,7 @@ export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => {
throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion'));
}
let workflow = deepClone(data) as WorkflowV1 | WorkflowV2 | WorkflowV3;
let workflow = cloneDeep(data) as WorkflowV1 | WorkflowV2 | WorkflowV3;
if (workflow.meta.version === '1.0.0') {
const v1 = zWorkflowV1.parse(workflow);

View File

@@ -280,7 +280,6 @@ const migrateGenerationState = (state: any): GenerationState => {
// The signature of the model has changed, so we need to reset it
state._version = 2;
state.model = null;
state.canvasCoherenceMode = initialGenerationState.canvasCoherenceMode;
}
return state;
};

View File

@@ -61,7 +61,7 @@ export const AdvancedSettingsAccordion = memo(() => {
return (
<StandaloneAccordion label={t('accordions.advanced.title')} badges={badges} isOpen={isOpen} onToggle={onToggle}>
<Flex gap={4} alignItems="center" p={4} flexDir="column" data-testid="advanced-settings-accordion">
<Flex gap={4} alignItems="center" p={4} flexDir="column">
<Flex gap={4} w="full">
<ParamVAEModelSelect />
<ParamVAEPrecision />

View File

@@ -77,7 +77,7 @@ export const ControlSettingsAccordion: React.FC = memo(() => {
return (
<StandaloneAccordion label={t('accordions.control.title')} badges={badges} isOpen={isOpen} onToggle={onToggle}>
<Flex gap={2} p={4} flexDir="column" data-testid="control-accordion">
<Flex gap={2} p={4} flexDir="column">
<ButtonGroup size="sm" w="full" justifyContent="space-between" variant="ghost" isAttached={false}>
<Button
tooltip={t('controlnet.addControlNet')}

View File

@@ -53,7 +53,7 @@ export const GenerationSettingsAccordion = memo(() => {
isOpen={isOpenAccordion}
onToggle={onToggleAccordion}
>
<Box px={4} pt={4} data-testid="generation-accordion">
<Box px={4} pt={4}>
<Flex gap={4} flexDir="column">
<Flex gap={4} alignItems="center">
<ParamMainModelSelect />

View File

@@ -83,7 +83,7 @@ export const ImageSettingsAccordion = memo(() => {
isOpen={isOpenAccordion}
onToggle={onToggleAccordion}
>
<Flex px={4} pt={4} w="full" h="full" flexDir="column" data-testid="image-settings-accordion">
<Flex px={4} pt={4} w="full" h="full" flexDir="column">
{activeTabName === 'unifiedCanvas' ? <ImageSizeCanvas /> : <ImageSizeLinear />}
<Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>
<Flex gap={4} pb={4} flexDir="column">

View File

@@ -195,7 +195,6 @@ export const modelsApi = api.injectEndpoints({
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
};
},
providesTags: [{ type: 'ModelScanFolderResults', id: LIST_TAG }],
}),
getHuggingFaceModels: build.query<GetHuggingFaceModelsResponse, string>({
query: (hugging_face_repo) => {

View File

@@ -192,7 +192,7 @@ export const queueApi = api.injectEndpoints({
{ batch_id: string }
>({
query: ({ batch_id }) => ({
url: buildQueueUrl(`b/${batch_id}/status`),
url: buildQueueUrl(`/b/${batch_id}/status`),
method: 'GET',
}),
providesTags: (result) => {

View File

@@ -29,7 +29,6 @@ const tagTypes = [
'InvocationCacheStatus',
'ModelConfig',
'ModelInstalls',
'ModelScanFolderResults',
'T2IAdapterModel',
'MainModel',
'VaeModel',

File diff suppressed because one or more lines are too long

View File

@@ -46,7 +46,7 @@ export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
// TODO(MM2): Can we rename this from Vae -> VAE
export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
type DiffusersModelConfig = S['MainDiffusersConfig'];

View File

@@ -1 +1 @@
__version__ = "4.0.2"
__version__ = "4.0.0rc6"

View File

@@ -44,6 +44,7 @@ dependencies = [
"onnx==1.15.0",
"onnxruntime==1.16.3",
"opencv-python==4.9.0.80",
"peft==0.9.0",
"pytorch-lightning==2.1.3",
"safetensors==0.4.2",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
@@ -73,7 +74,7 @@ dependencies = [
"easing-functions",
"einops",
"facexlib",
"matplotlib", # needed for plotting of Penner easing functions
"matplotlib", # needed for plotting of Penner easing functions
"npyscreen",
"omegaconf",
"picklescan",

View File

@@ -43,7 +43,8 @@ def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_fil
assert model_record is not None
assert model_record.name == "test_embedding"
assert model_record.type == ModelType.TextualInversion
assert Path(model_record.path) == embedding_file
assert model_record.path.endswith(embedding_file.as_posix())
assert Path(model_record.path).is_absolute()
assert Path(model_record.path).exists()
assert model_record.base == BaseModelType("sd-1")
assert model_record.description is not None
@@ -76,7 +77,8 @@ def test_install(
key = mm2_installer.install_path(embedding_file)
model_record = store.get_model(key)
assert model_record.path.endswith("sd-1/embedding/test_embedding.safetensors")
assert (mm2_app_config.models_path / model_record.path).exists()
assert Path(model_record.path).is_absolute()
assert Path(model_record.path).exists()
assert model_record.source == embedding_file.as_posix()
@@ -87,11 +89,9 @@ def test_rename(
key = mm2_installer.install_path(embedding_file)
model_record = store.get_model(key)
assert model_record.path.endswith("sd-1/embedding/test_embedding.safetensors")
store.update_model(key, ModelRecordChanges(name="new model name", base=BaseModelType("sd-2")))
store.update_model(key, ModelRecordChanges(name="new_name.safetensors", base=BaseModelType("sd-2")))
new_model_record = mm2_installer.sync_model_path(key)
# Renaming the model record shouldn't rename the file
assert new_model_record.name == "new model name"
assert new_model_record.path.endswith("sd-2/embedding/test_embedding.safetensors")
assert new_model_record.path.endswith("sd-2/embedding/new_name.safetensors")
@pytest.mark.parametrize(
@@ -147,7 +147,10 @@ def test_background_install(
model_record = mm2_installer.record_store.get_model(key)
assert model_record is not None
assert model_record.path.endswith(destination)
assert (mm2_app_config.models_path / model_record.path).exists()
assert Path(model_record.path).is_absolute()
assert Path(model_record.path).exists()
assert model_record.key != "<NOKEY>"
assert Path(model_record.path).exists()
# see if metadata was properly passed through
assert model_record.description == description
@@ -169,7 +172,7 @@ def test_not_inplace_install(
assert job is not None
assert job.config_out is not None
assert Path(job.config_out.path) != embedding_file
assert (mm2_app_config.models_path / job.config_out.path).exists()
assert Path(job.config_out.path).exists()
def test_inplace_install(
@@ -181,21 +184,16 @@ def test_inplace_install(
assert job is not None
assert job.config_out is not None
assert Path(job.config_out.path) == embedding_file
assert Path(job.config_out.path).exists()
def test_delete_install(
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
) -> None:
def test_delete_install(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
store = mm2_installer.record_store
key = mm2_installer.install_path(embedding_file)
model_record = store.get_model(key)
assert (mm2_app_config.models_path / model_record.path).exists()
assert Path(model_record.path).exists()
assert embedding_file.exists() # original should still be there after installation
mm2_installer.delete(key)
assert not (
mm2_app_config.models_path / model_record.path
).exists() # after deletion, installed copy should not exist
assert not Path(model_record.path).exists() # after deletion, installed copy should not exist
assert embedding_file.exists() # but original should still be there
with pytest.raises(UnknownModelException):
store.get_model(key)
@@ -234,7 +232,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
key = job.config_out.key
model_record = store.get_model(key)
assert (mm2_app_config.models_path / model_record.path).exists()
assert Path(model_record.path).exists()
assert len(bus.events) == 4
event_names = [x.event_name for x in bus.events]
@@ -263,7 +261,7 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
key = job.config_out.key
model_record = store.get_model(key)
assert (mm2_app_config.models_path / model_record.path).exists()
assert Path(model_record.path).exists()
assert model_record.type == ModelType.Main
assert model_record.format == ModelFormat.Diffusers

View File

@@ -5,7 +5,7 @@
import pytest
import torch
from invokeai.backend.lora import LoRALayer, LoRAModelRaw
from invokeai.backend.lora_model_raw import LoRALayer, LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher

View File

@@ -98,32 +98,6 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
assert not hasattr(config, "esrgan")
@pytest.mark.parametrize(
"legacy_conf_dir,expected_value,expected_is_set",
[
# not set, expected value is the default value
("configs/stable-diffusion", Path("configs"), False),
# not set, expected value is the default value
("configs\\stable-diffusion", Path("configs"), False),
# set, best-effort resolution of the path
("partial_custom_path/stable-diffusion", Path("partial_custom_path"), True),
# set, exact path
("full/custom/path", Path("full/custom/path"), True),
],
)
def test_migrate_v3_legacy_conf_dir_defaults(
tmp_path: Path, patch_rootdir: None, legacy_conf_dir: str, expected_value: Path, expected_is_set: bool
):
"""Test reading configuration from a file."""
config_content = f"InvokeAI:\n Paths:\n legacy_conf_dir: {legacy_conf_dir}"
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(config_content)
config = load_and_migrate_config(temp_config_file)
assert config.legacy_conf_dir == expected_value
assert ("legacy_conf_dir" in config.model_fields_set) is expected_is_set
def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
"""Test the backup of the config file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"

View File

@@ -250,32 +250,6 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
db.conn.close()
def test_migrator_backs_up_db(logger: Logger) -> None:
with TemporaryDirectory() as tempdir:
original_db_path = Path(tempdir) / "invokeai.db"
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
# Write some data to the db to test for successful backup
temp_cursor = db.conn.cursor()
temp_cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
db.conn.commit()
# Set up the migrator
migrator = SqliteMigrator(db=db)
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
migrator.run_migrations()
# Must manually close else we get an error on Windows
db.conn.close()
assert original_db_path.exists()
# We should have a backup file when we migrated a file db
assert migrator._backup_path
# Check that the test table exists as a proxy for successful backup
with closing(sqlite3.connect(migrator._backup_path)) as backup_db_conn:
backup_db_cursor = backup_db_conn.cursor()
backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
assert backup_db_cursor.fetchone() is not None
def test_migrator_makes_no_changes_on_failed_migration(
migrator: SqliteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
) -> None: