Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6bfb4927c7 | ||
|
|
c15e9e23ca | ||
|
|
e1aa1ed6af | ||
|
|
4b68050c9b | ||
|
|
9e68a5c851 | ||
|
|
61a672cd81 | ||
|
|
c27a2e59da | ||
|
|
4e3f42e388 | ||
|
|
5d41157404 | ||
|
|
8db4ba252a | ||
|
|
6d9fb207f0 |
@@ -40,25 +40,6 @@ Follow the same steps to scan and import the missing models.
|
||||
- Check the `ram` setting in `invokeai.yaml`. This setting tells Invoke how much of your system RAM can be used to cache models. Having this too high or too low can slow things down. That said, it's generally safest to not set this at all and instead let Invoke manage it.
|
||||
- Check the `vram` setting in `invokeai.yaml`. This setting tells Invoke how much of your GPU VRAM can be used to cache models. Counter-intuitively, if this setting is too high, Invoke will need to do a lot of shuffling of models as it juggles the VRAM cache and the currently-loaded model. The default value of 0.25 is generally works well for GPUs without 16GB or more VRAM. Even on a 24GB card, the default works well.
|
||||
- Check that your generations are happening on your GPU (if you have one). InvokeAI will log what is being used for generation upon startup. If your GPU isn't used, re-install to ensure the correct versions of torch get installed.
|
||||
- If you are on Windows, you may have exceeded your GPU's VRAM capacity and are using slower [shared GPU memory](#shared-gpu-memory-windows). There's a guide to opt out of this behaviour in the linked FAQ entry.
|
||||
|
||||
## Shared GPU Memory (Windows)
|
||||
|
||||
!!! tip "Nvidia GPUs with driver 536.40"
|
||||
|
||||
This only applies to current Nvidia cards with driver 536.40 or later, released in June 2023.
|
||||
|
||||
When the GPU doesn't have enough VRAM for a task, Windows is able to allocate some of its CPU RAM to the GPU. This is much slower than VRAM, but it does allow the system to generate when it otherwise might no have enough VRAM.
|
||||
|
||||
When shared GPU memory is used, generation slows down dramatically - but at least it doesn't crash.
|
||||
|
||||
If you'd like to opt out of this behavior and instead get an error when you exceed your GPU's VRAM, follow [this guide from Nvidia](https://nvidia.custhelp.com/app/answers/detail/a_id/5490).
|
||||
|
||||
Here's how to get the python path required in the linked guide:
|
||||
|
||||
- Run `invoke.bat`.
|
||||
- Select option 2 for developer console.
|
||||
- At least one python path will be printed. Copy the path that includes your invoke installation directory (typically the first).
|
||||
|
||||
## Installer cannot find python (Windows)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
from invokeai.backend.util.logging import logging
|
||||
from invokeai.version import __version__
|
||||
@@ -100,7 +100,7 @@ async def get_app_deps() -> AppDependencyVersions:
|
||||
|
||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||
async def get_config() -> AppConfig:
|
||||
infill_methods = ["tile", "lama", "cv2", "color"] # TODO: add mosaic back
|
||||
infill_methods = ["tile", "lama", "cv2"]
|
||||
if PatchMatch.patchmatch_available():
|
||||
infill_methods.append("patchmatch")
|
||||
|
||||
|
||||
@@ -9,7 +9,8 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_model import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_model_patcher import LoraModelPatcher
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
@@ -80,7 +81,7 @@ class CompelInvocation(BaseInvocation):
|
||||
),
|
||||
text_encoder_info as text_encoder,
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
LoraModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
|
||||
):
|
||||
@@ -181,7 +182,7 @@ class SDXLPromptInvocationBase:
|
||||
),
|
||||
text_encoder_info as text_encoder,
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
LoraModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
|
||||
):
|
||||
|
||||
@@ -1,91 +1,154 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Literal, get_args
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
|
||||
from PIL import Image
|
||||
import math
|
||||
from typing import Literal, Optional, get_args
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from invokeai.app.invocations.fields import ColorField, ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
from invokeai.backend.image_util.infill_methods.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.infill_methods.lama import LaMA
|
||||
from invokeai.backend.image_util.infill_methods.mosaic import infill_mosaic
|
||||
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch, infill_patchmatch
|
||||
from invokeai.backend.image_util.infill_methods.tile import infill_tile
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, WithBoard, WithMetadata
|
||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
def get_infill_methods():
|
||||
methods = Literal["tile", "color", "lama", "cv2"] # TODO: add mosaic back
|
||||
def infill_methods() -> list[str]:
|
||||
methods = ["tile", "solid", "lama", "cv2"]
|
||||
if PatchMatch.patchmatch_available():
|
||||
methods = Literal["patchmatch", "tile", "color", "lama", "cv2"] # TODO: add mosaic back
|
||||
methods.insert(0, "patchmatch")
|
||||
return methods
|
||||
|
||||
|
||||
INFILL_METHODS = get_infill_methods()
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
|
||||
|
||||
class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Base class for invocations that preprocess images for Infilling"""
|
||||
def infill_lama(im: Image.Image) -> Image.Image:
|
||||
lama = LaMA()
|
||||
return lama(im)
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
@abstractmethod
|
||||
def infill(self, image: Image.Image) -> Image.Image:
|
||||
"""Infill the image with the specified method"""
|
||||
pass
|
||||
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
|
||||
"""Process the image to have an alpha channel before being infilled"""
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
has_alpha = True if image.mode == "RGBA" else False
|
||||
return image, has_alpha
|
||||
# Skip patchmatch if patchmatch isn't available
|
||||
if not PatchMatch.patchmatch_available():
|
||||
return im
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Retrieve and process image to be infilled
|
||||
input_image, has_alpha = self.load_image(context)
|
||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||
im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
|
||||
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
||||
return im_patched
|
||||
|
||||
# If the input image has no alpha channel, return it
|
||||
if has_alpha is False:
|
||||
return ImageOutput.build(context.images.get_dto(self.image.image_name))
|
||||
|
||||
# Perform Infill action
|
||||
infilled_image = self.infill(input_image)
|
||||
def infill_cv2(im: Image.Image) -> Image.Image:
|
||||
return cv2_inpaint(im)
|
||||
|
||||
# Create ImageDTO for Infilled Image
|
||||
infilled_image_dto = context.images.save(image=infilled_image)
|
||||
|
||||
# Return Infilled Image
|
||||
return ImageOutput.build(infilled_image_dto)
|
||||
def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
|
||||
nrows, _m = divmod(_nrows, height)
|
||||
ncols, _n = divmod(_ncols, width)
|
||||
if _m != 0 or _n != 0:
|
||||
return None
|
||||
|
||||
return np.lib.stride_tricks.as_strided(
|
||||
np.ravel(image),
|
||||
shape=(nrows, ncols, height, width, depth),
|
||||
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||
writeable=False,
|
||||
)
|
||||
|
||||
|
||||
def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size_tuple = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = get_tile_images(a, *tile_size_tuple).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:, :, :, :, 3]
|
||||
|
||||
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||
tmask_shape = tiles_mask.shape
|
||||
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||
tiles_mask = tiles_mask > 0
|
||||
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
|
||||
|
||||
# Get RGB tiles in single array and filter by the mask
|
||||
tshape = tiles.shape
|
||||
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
|
||||
filtered_tiles = tiles_all[tiles_mask]
|
||||
|
||||
if len(filtered_tiles) == 0:
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum() # noqa: E712
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
tiles_all = tiles_all.swapaxes(1, 2)
|
||||
st = tiles_all.reshape(
|
||||
(
|
||||
math.prod(tiles_all.shape[0:2]),
|
||||
math.prod(tiles_all.shape[2:4]),
|
||||
tiles_all.shape[4],
|
||||
)
|
||||
)
|
||||
si = Image.fromarray(st, mode="RGBA")
|
||||
|
||||
return si
|
||||
|
||||
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
class InfillColorInvocation(InfillImageProcessorInvocation):
|
||||
class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
color: ColorField = InputField(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
return infilled
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.3")
|
||||
class InfillTileInvocation(InfillImageProcessorInvocation):
|
||||
class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = InputField(
|
||||
default=0,
|
||||
@@ -94,74 +157,92 @@ class InfillTileInvocation(InfillImageProcessorInvocation):
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
)
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
|
||||
return output.infilled
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2"
|
||||
)
|
||||
class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
|
||||
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name).convert("RGBA")
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
infill_image = image.copy()
|
||||
width = int(image.width / self.downscale)
|
||||
height = int(image.height / self.downscale)
|
||||
|
||||
infilled = image.resize(
|
||||
infill_image = infill_image.resize(
|
||||
(width, height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
infilled = infill_patchmatch(image)
|
||||
|
||||
if PatchMatch.patchmatch_available():
|
||||
infilled = infill_patchmatch(infill_image)
|
||||
else:
|
||||
raise ValueError("PatchMatch is not available on this system")
|
||||
|
||||
infilled = infilled.resize(
|
||||
(image.width, image.height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
infilled.paste(image, (0, 0), mask=image.split()[-1])
|
||||
|
||||
return infilled
|
||||
infilled.paste(image, (0, 0), mask=image.split()[-1])
|
||||
# image.paste(infilled, (0, 0), mask=image.split()[-1])
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
class LaMaInfillInvocation(InfillImageProcessorInvocation):
|
||||
class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
lama = LaMA()
|
||||
return lama(image)
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Downloads the LaMa model if it doesn't already exist
|
||||
download_with_progress_bar(
|
||||
name="LaMa Inpainting Model",
|
||||
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
dest_path=context.config.get().models_path / "core/misc/lama/lama.pt",
|
||||
)
|
||||
|
||||
infilled = infill_lama(image.copy())
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
class CV2InfillInvocation(InfillImageProcessorInvocation):
|
||||
class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
return cv2_inpaint(image)
|
||||
|
||||
|
||||
# @invocation(
|
||||
# "infill_mosaic", title="Mosaic Infill", tags=["image", "inpaint", "outpaint"], category="inpaint", version="1.0.0"
|
||||
# )
|
||||
class MosaicInfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image with a mosaic pattern drawing colors from the rest of the image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
tile_width: int = InputField(default=64, description="Width of the tile")
|
||||
tile_height: int = InputField(default=64, description="Height of the tile")
|
||||
min_color: ColorField = InputField(
|
||||
default=ColorField(r=0, g=0, b=0, a=255),
|
||||
description="The min threshold for color",
|
||||
)
|
||||
max_color: ColorField = InputField(
|
||||
default=ColorField(r=255, g=255, b=255, a=255),
|
||||
description="The max threshold for color",
|
||||
)
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
infilled = infill_cv2(image.copy())
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@@ -65,9 +65,9 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
ui_order=-1,
|
||||
ui_type=UIType.IPAdapterModel,
|
||||
)
|
||||
clip_vision_model: Literal["ViT-H", "ViT-G"] = InputField(
|
||||
clip_vision_model: Literal["auto", "ViT-H", "ViT-G"] = InputField(
|
||||
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
|
||||
default="ViT-H",
|
||||
default="auto",
|
||||
ui_order=2,
|
||||
)
|
||||
weight: Union[float, List[float]] = InputField(
|
||||
@@ -96,9 +96,14 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
|
||||
|
||||
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
if self.clip_vision_model == "auto":
|
||||
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"You need to set the appropriate CLIP Vision model for checkpoint IP Adapter models."
|
||||
)
|
||||
else:
|
||||
image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
|
||||
|
||||
|
||||
@@ -48,7 +48,8 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_model import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_model_patcher import LoraModelPatcher
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
@@ -730,7 +731,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
|
||||
unet_info as unet,
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
|
||||
LoraModelPatcher.apply_lora_unet(unet, _lora_loader()),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
@@ -1254,7 +1255,7 @@ class IdealSizeInvocation(BaseInvocation):
|
||||
return tuple((x - x % multiple_of) for x in args)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
unet_config = context.models.get_config(**self.unet.unet.model_dump())
|
||||
aspect = self.width / self.height
|
||||
dimension: float = 512
|
||||
if unet_config.base == BaseModelType.StableDiffusion2:
|
||||
|
||||
@@ -5,7 +5,8 @@ from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
@@ -6,7 +6,8 @@ from typing import Optional, Type
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.load import (
|
||||
LoadedModel,
|
||||
ModelLoaderRegistry,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Initialization file for model manager service."""
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
|
||||
from .model_manager_default import ModelManagerService, ModelManagerServiceBase
|
||||
@@ -8,7 +8,6 @@ from .model_manager_default import ModelManagerService, ModelManagerServiceBase
|
||||
__all__ = [
|
||||
"ModelManagerServiceBase",
|
||||
"ModelManagerService",
|
||||
"AnyModel",
|
||||
"AnyModelConfig",
|
||||
"BaseModelType",
|
||||
"ModelType",
|
||||
|
||||
@@ -80,7 +80,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=app_config.ram,
|
||||
max_vram_cache_size=app_config.vram,
|
||||
lazy_offloading=app_config.lazy_offload,
|
||||
logger=logger,
|
||||
execution_device=execution_device,
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Initialization file for invokeai.backend.image_util methods.
|
||||
"""
|
||||
|
||||
from .infill_methods.patchmatch import PatchMatch # noqa: F401
|
||||
from .patchmatch import PatchMatch # noqa: F401
|
||||
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
|
||||
from .seamless import configure_model_padding # noqa: F401
|
||||
from .util import InitImageResizer, make_grid # noqa: F401
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def infill_mosaic(
|
||||
image: Image.Image,
|
||||
tile_shape: Tuple[int, int] = (64, 64),
|
||||
min_color: Tuple[int, int, int, int] = (0, 0, 0, 0),
|
||||
max_color: Tuple[int, int, int, int] = (255, 255, 255, 0),
|
||||
) -> Image.Image:
|
||||
"""
|
||||
image:PIL - A PIL Image
|
||||
tile_shape: Tuple[int,int] - Tile width & Tile Height
|
||||
min_color: Tuple[int,int,int] - RGB values for the lowest color to clip to (0-255)
|
||||
max_color: Tuple[int,int,int] - RGB values for the highest color to clip to (0-255)
|
||||
"""
|
||||
|
||||
np_image = np.array(image) # Convert image to np array
|
||||
alpha = np_image[:, :, 3] # Get the mask from the alpha channel of the image
|
||||
non_transparent_pixels = np_image[alpha != 0, :3] # List of non-transparent pixels
|
||||
|
||||
# Create color tiles to paste in the empty areas of the image
|
||||
tile_width, tile_height = tile_shape
|
||||
|
||||
# Clip the range of colors in the image to a particular spectrum only
|
||||
r_min, g_min, b_min, _ = min_color
|
||||
r_max, g_max, b_max, _ = max_color
|
||||
non_transparent_pixels[:, 0] = np.clip(non_transparent_pixels[:, 0], r_min, r_max)
|
||||
non_transparent_pixels[:, 1] = np.clip(non_transparent_pixels[:, 1], g_min, g_max)
|
||||
non_transparent_pixels[:, 2] = np.clip(non_transparent_pixels[:, 2], b_min, b_max)
|
||||
|
||||
tiles = []
|
||||
for _ in range(256):
|
||||
color = non_transparent_pixels[np.random.randint(len(non_transparent_pixels))]
|
||||
tile = np.zeros((tile_height, tile_width, 3), dtype=np.uint8)
|
||||
tile[:, :] = color
|
||||
tiles.append(tile)
|
||||
|
||||
# Fill the transparent area with tiles
|
||||
filled_image = np.zeros((image.height, image.width, 3), dtype=np.uint8)
|
||||
|
||||
for x in range(image.width):
|
||||
for y in range(image.height):
|
||||
tile = tiles[np.random.randint(len(tiles))]
|
||||
try:
|
||||
filled_image[
|
||||
y - (y % tile_height) : y - (y % tile_height) + tile_height,
|
||||
x - (x % tile_width) : x - (x % tile_width) + tile_width,
|
||||
] = tile
|
||||
except ValueError:
|
||||
# Need to handle edge cases - literally
|
||||
pass
|
||||
|
||||
filled_image = Image.fromarray(filled_image) # Convert the filled tiles image to PIL
|
||||
image = Image.composite(
|
||||
image, filled_image, image.split()[-1]
|
||||
) # Composite the original image on top of the filled tiles
|
||||
return image
|
||||
@@ -1,67 +0,0 @@
|
||||
"""
|
||||
This module defines a singleton object, "patchmatch" that
|
||||
wraps the actual patchmatch object. It respects the global
|
||||
"try_patchmatch" attribute, so that patchmatch loading can
|
||||
be suppressed or deferred
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
|
||||
class PatchMatch:
|
||||
"""
|
||||
Thin class wrapper around the patchmatch function.
|
||||
"""
|
||||
|
||||
patch_match = None
|
||||
tried_load: bool = False
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def _load_patch_match(cls):
|
||||
if cls.tried_load:
|
||||
return
|
||||
if get_config().patchmatch:
|
||||
from patchmatch import patch_match as pm
|
||||
|
||||
if pm.patchmatch_available:
|
||||
logger.info("Patchmatch initialized")
|
||||
cls.patch_match = pm
|
||||
else:
|
||||
logger.info("Patchmatch not loaded (nonfatal)")
|
||||
else:
|
||||
logger.info("Patchmatch loading disabled")
|
||||
cls.tried_load = True
|
||||
|
||||
@classmethod
|
||||
def patchmatch_available(cls) -> bool:
|
||||
cls._load_patch_match()
|
||||
if not cls.patch_match:
|
||||
return False
|
||||
return cls.patch_match.patchmatch_available
|
||||
|
||||
@classmethod
|
||||
def inpaint(cls, image: Image.Image) -> Image.Image:
|
||||
if cls.patch_match is None or not cls.patchmatch_available():
|
||||
return image
|
||||
|
||||
np_image = np.array(image)
|
||||
mask = 255 - np_image[:, :, 3]
|
||||
infilled = cls.patch_match.inpaint(np_image[:, :, :3], mask, patch_size=3)
|
||||
return Image.fromarray(infilled, mode="RGB")
|
||||
|
||||
|
||||
def infill_patchmatch(image: Image.Image) -> Image.Image:
|
||||
IS_PATCHMATCH_AVAILABLE = PatchMatch.patchmatch_available()
|
||||
|
||||
if not IS_PATCHMATCH_AVAILABLE:
|
||||
logger.warning("PatchMatch is not available on this system")
|
||||
return image
|
||||
|
||||
return PatchMatch.inpaint(image)
|
||||
|
Before Width: | Height: | Size: 45 KiB |
|
Before Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 36 KiB |
|
Before Width: | Height: | Size: 33 KiB |
|
Before Width: | Height: | Size: 21 KiB |
|
Before Width: | Height: | Size: 39 KiB |
|
Before Width: | Height: | Size: 42 KiB |
|
Before Width: | Height: | Size: 48 KiB |
|
Before Width: | Height: | Size: 49 KiB |
|
Before Width: | Height: | Size: 60 KiB |
@@ -1,95 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\"\"\"Smoke test for the tile infill\"\"\"\n",
|
||||
"\n",
|
||||
"from pathlib import Path\n",
|
||||
"from typing import Optional\n",
|
||||
"from PIL import Image\n",
|
||||
"from invokeai.backend.image_util.infill_methods.tile import infill_tile\n",
|
||||
"\n",
|
||||
"images: list[tuple[str, Image.Image]] = []\n",
|
||||
"\n",
|
||||
"for i in sorted(Path(\"./test_images/\").glob(\"*.webp\")):\n",
|
||||
" images.append((i.name, Image.open(i)))\n",
|
||||
" images.append((i.name, Image.open(i).transpose(Image.FLIP_LEFT_RIGHT)))\n",
|
||||
" images.append((i.name, Image.open(i).transpose(Image.FLIP_TOP_BOTTOM)))\n",
|
||||
" images.append((i.name, Image.open(i).resize((512, 512))))\n",
|
||||
" images.append((i.name, Image.open(i).resize((1234, 461))))\n",
|
||||
"\n",
|
||||
"outputs: list[tuple[str, Image.Image, Image.Image, Optional[Image.Image]]] = []\n",
|
||||
"\n",
|
||||
"for name, image in images:\n",
|
||||
" try:\n",
|
||||
" output = infill_tile(image, seed=0, tile_size=32)\n",
|
||||
" outputs.append((name, image, output.infilled, output.tile_image))\n",
|
||||
" except ValueError as e:\n",
|
||||
" print(f\"Skipping image {name}: {e}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Display the images in jupyter notebook\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from PIL import ImageOps\n",
|
||||
"\n",
|
||||
"fig, axes = plt.subplots(len(outputs), 3, figsize=(10, 3 * len(outputs)))\n",
|
||||
"plt.subplots_adjust(hspace=0)\n",
|
||||
"\n",
|
||||
"for i, (name, original, infilled, tile_image) in enumerate(outputs):\n",
|
||||
" # Add a border to each image, helps to see the edges\n",
|
||||
" size = original.size\n",
|
||||
" original = ImageOps.expand(original, border=5, fill=\"red\")\n",
|
||||
" filled = ImageOps.expand(infilled, border=5, fill=\"red\")\n",
|
||||
" if tile_image:\n",
|
||||
" tile_image = ImageOps.expand(tile_image, border=5, fill=\"red\")\n",
|
||||
"\n",
|
||||
" axes[i, 0].imshow(original)\n",
|
||||
" axes[i, 0].axis(\"off\")\n",
|
||||
" axes[i, 0].set_title(f\"Original ({name} - {size})\")\n",
|
||||
"\n",
|
||||
" if tile_image:\n",
|
||||
" axes[i, 1].imshow(tile_image)\n",
|
||||
" axes[i, 1].axis(\"off\")\n",
|
||||
" axes[i, 1].set_title(\"Tile Image\")\n",
|
||||
" else:\n",
|
||||
" axes[i, 1].axis(\"off\")\n",
|
||||
" axes[i, 1].set_title(\"NO TILES GENERATED (NO TRANSPARENCY)\")\n",
|
||||
"\n",
|
||||
" axes[i, 2].imshow(filled)\n",
|
||||
" axes[i, 2].axis(\"off\")\n",
|
||||
" axes[i, 2].set_title(\"Filled\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".invokeai",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1,122 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def create_tile_pool(img_array: np.ndarray, tile_size: tuple[int, int]) -> list[np.ndarray]:
|
||||
"""
|
||||
Create a pool of tiles from non-transparent areas of the image by systematically walking through the image.
|
||||
|
||||
Args:
|
||||
img_array: numpy array of the image.
|
||||
tile_size: tuple (tile_width, tile_height) specifying the size of each tile.
|
||||
|
||||
Returns:
|
||||
A list of numpy arrays, each representing a tile.
|
||||
"""
|
||||
tiles: list[np.ndarray] = []
|
||||
rows, cols = img_array.shape[:2]
|
||||
tile_width, tile_height = tile_size
|
||||
|
||||
for y in range(0, rows - tile_height + 1, tile_height):
|
||||
for x in range(0, cols - tile_width + 1, tile_width):
|
||||
tile = img_array[y : y + tile_height, x : x + tile_width]
|
||||
# Check if the image has an alpha channel and the tile is completely opaque
|
||||
if img_array.shape[2] == 4 and np.all(tile[:, :, 3] == 255):
|
||||
tiles.append(tile)
|
||||
elif img_array.shape[2] == 3: # If no alpha channel, append the tile
|
||||
tiles.append(tile)
|
||||
|
||||
if not tiles:
|
||||
raise ValueError(
|
||||
"Not enough opaque pixels to generate any tiles. Use a smaller tile size or a different image."
|
||||
)
|
||||
|
||||
return tiles
|
||||
|
||||
|
||||
def create_filled_image(
|
||||
img_array: np.ndarray, tile_pool: list[np.ndarray], tile_size: tuple[int, int], seed: int
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Create an image of the same dimensions as the original, filled entirely with tiles from the pool.
|
||||
|
||||
Args:
|
||||
img_array: numpy array of the original image.
|
||||
tile_pool: A list of numpy arrays, each representing a tile.
|
||||
tile_size: tuple (tile_width, tile_height) specifying the size of each tile.
|
||||
|
||||
Returns:
|
||||
A numpy array representing the filled image.
|
||||
"""
|
||||
|
||||
rows, cols, _ = img_array.shape
|
||||
tile_width, tile_height = tile_size
|
||||
|
||||
# Prep an empty RGB image
|
||||
filled_img_array = np.zeros((rows, cols, 3), dtype=img_array.dtype)
|
||||
|
||||
# Make the random tile selection reproducible
|
||||
rng = np.random.default_rng(seed)
|
||||
|
||||
for y in range(0, rows, tile_height):
|
||||
for x in range(0, cols, tile_width):
|
||||
# Pick a random tile from the pool
|
||||
tile = tile_pool[rng.integers(len(tile_pool))]
|
||||
|
||||
# Calculate the space available (may be less than tile size near the edges)
|
||||
space_y = min(tile_height, rows - y)
|
||||
space_x = min(tile_width, cols - x)
|
||||
|
||||
# Crop the tile if necessary to fit into the available space
|
||||
cropped_tile = tile[:space_y, :space_x, :3]
|
||||
|
||||
# Fill the available space with the (possibly cropped) tile
|
||||
filled_img_array[y : y + space_y, x : x + space_x, :3] = cropped_tile
|
||||
|
||||
return filled_img_array
|
||||
|
||||
|
||||
@dataclass
|
||||
class InfillTileOutput:
|
||||
infilled: Image.Image
|
||||
tile_image: Optional[Image.Image] = None
|
||||
|
||||
|
||||
def infill_tile(image_to_infill: Image.Image, seed: int, tile_size: int) -> InfillTileOutput:
|
||||
"""Infills an image with random tiles from the image itself.
|
||||
|
||||
If the image is not an RGBA image, it is returned untouched.
|
||||
|
||||
Args:
|
||||
image: The image to infill.
|
||||
tile_size: The size of the tiles to use for infilling.
|
||||
|
||||
Raises:
|
||||
ValueError: If there are not enough opaque pixels to generate any tiles.
|
||||
"""
|
||||
|
||||
if image_to_infill.mode != "RGBA":
|
||||
return InfillTileOutput(infilled=image_to_infill)
|
||||
|
||||
# Internally, we want a tuple of (tile_width, tile_height). In the future, the tile size can be any rectangle.
|
||||
_tile_size = (tile_size, tile_size)
|
||||
np_image = np.array(image_to_infill, dtype=np.uint8)
|
||||
|
||||
# Create the pool of tiles that we will use to infill
|
||||
tile_pool = create_tile_pool(np_image, _tile_size)
|
||||
|
||||
# Create an image from the tiles, same size as the original
|
||||
tile_np_image = create_filled_image(np_image, tile_pool, _tile_size, seed)
|
||||
|
||||
# Paste the OG image over the tile image, effectively infilling the area
|
||||
tile_image = Image.fromarray(tile_np_image, "RGB")
|
||||
infilled = tile_image.copy()
|
||||
infilled.paste(image_to_infill, (0, 0), image_to_infill.split()[-1])
|
||||
|
||||
# I think we want this to be "RGBA"?
|
||||
infilled.convert("RGBA")
|
||||
|
||||
return InfillTileOutput(infilled=infilled, tile_image=tile_image)
|
||||
@@ -7,7 +7,6 @@ from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
|
||||
@@ -31,14 +30,6 @@ class LaMA:
|
||||
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
||||
device = choose_torch_device()
|
||||
model_location = get_config().models_path / "core/misc/lama/lama.pt"
|
||||
|
||||
if not model_location.exists():
|
||||
download_with_progress_bar(
|
||||
name="LaMa Inpainting Model",
|
||||
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
dest_path=model_location,
|
||||
)
|
||||
|
||||
model = load_jit_model(model_location, device)
|
||||
|
||||
image = np.asarray(input_image.convert("RGB"))
|
||||
49
invokeai/backend/image_util/patchmatch.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
This module defines a singleton object, "patchmatch" that
|
||||
wraps the actual patchmatch object. It respects the global
|
||||
"try_patchmatch" attribute, so that patchmatch loading can
|
||||
be suppressed or deferred
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
|
||||
class PatchMatch:
|
||||
"""
|
||||
Thin class wrapper around the patchmatch function.
|
||||
"""
|
||||
|
||||
patch_match = None
|
||||
tried_load: bool = False
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def _load_patch_match(self):
|
||||
if self.tried_load:
|
||||
return
|
||||
if get_config().patchmatch:
|
||||
from patchmatch import patch_match as pm
|
||||
|
||||
if pm.patchmatch_available:
|
||||
logger.info("Patchmatch initialized")
|
||||
else:
|
||||
logger.info("Patchmatch not loaded (nonfatal)")
|
||||
self.patch_match = pm
|
||||
else:
|
||||
logger.info("Patchmatch loading disabled")
|
||||
self.tried_load = True
|
||||
|
||||
@classmethod
|
||||
def patchmatch_available(self) -> bool:
|
||||
self._load_patch_match()
|
||||
return self.patch_match and self.patch_match.patchmatch_available
|
||||
|
||||
@classmethod
|
||||
def inpaint(self, *args, **kwargs) -> np.ndarray:
|
||||
if self.patchmatch_available():
|
||||
return self.patch_match.inpaint(*args, **kwargs)
|
||||
@@ -12,7 +12,6 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
|
||||
|
||||
from ..raw_model import RawModel
|
||||
from .resampler import Resampler
|
||||
|
||||
|
||||
@@ -102,7 +101,7 @@ class MLPProjModel(torch.nn.Module):
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
class IPAdapter(RawModel):
|
||||
class IPAdapter(torch.nn.Module):
|
||||
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
|
||||
|
||||
def __init__(
|
||||
@@ -112,6 +111,7 @@ class IPAdapter(RawModel):
|
||||
dtype: torch.dtype = torch.float16,
|
||||
num_tokens: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
|
||||
@@ -1,624 +0,0 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development team
|
||||
"""LoRA model support."""
|
||||
|
||||
import bisect
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
|
||||
from .raw_model import RawModel
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
# alpha: Optional[float]
|
||||
# bias: Optional[torch.Tensor]
|
||||
# layer_key: str
|
||||
|
||||
# @property
|
||||
# def scale(self):
|
||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
if "alpha" in values:
|
||||
self.alpha = values["alpha"].item()
|
||||
else:
|
||||
self.alpha = None
|
||||
|
||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
|
||||
values["bias_indices"],
|
||||
values["bias_values"],
|
||||
tuple(values["bias_size"]),
|
||||
)
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for val in [self.bias]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
class LoRALayer(LoRALayerBase):
|
||||
# up: torch.Tensor
|
||||
# mid: Optional[torch.Tensor]
|
||||
# down: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
if "lora_mid.weight" in values:
|
||||
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
|
||||
else:
|
||||
self.mid = None
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.up, self.mid, self.down]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
# w1_a: torch.Tensor
|
||||
# w1_b: torch.Tensor
|
||||
# w2_a: torch.Tensor
|
||||
# w2_b: torch.Tensor
|
||||
# t1: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
self.w2_a = values["hada_w2_a"]
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
|
||||
if "hada_t1" in values:
|
||||
self.t1: Optional[torch.Tensor] = values["hada_t1"]
|
||||
else:
|
||||
self.t1 = None
|
||||
|
||||
if "hada_t2" in values:
|
||||
self.t2: Optional[torch.Tensor] = values["hada_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
if self.t1 is None:
|
||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
||||
weight = rebuild1 * rebuild2
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
# w1: Optional[torch.Tensor] = None
|
||||
# w1_a: Optional[torch.Tensor] = None
|
||||
# w1_b: Optional[torch.Tensor] = None
|
||||
# w2: Optional[torch.Tensor] = None
|
||||
# w2_a: Optional[torch.Tensor] = None
|
||||
# w2_b: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
if "lokr_w1" in values:
|
||||
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
|
||||
self.w1_a = None
|
||||
self.w1_b = None
|
||||
else:
|
||||
self.w1 = None
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
|
||||
if "lokr_w2" in values:
|
||||
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
else:
|
||||
self.w2 = None
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
|
||||
if "lokr_t2" in values:
|
||||
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
if "lokr_w1_b" in values:
|
||||
self.rank = values["lokr_w1_b"].shape[0]
|
||||
elif "lokr_w2_b" in values:
|
||||
self.rank = values["lokr_w2_b"].shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
w1: Optional[torch.Tensor] = self.w1
|
||||
if w1 is None:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
assert w1 is not None
|
||||
assert w2 is not None
|
||||
weight = torch.kron(w1, w2)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["diff"]
|
||||
|
||||
if len(values.keys()) > 1:
|
||||
_keys = list(values.keys())
|
||||
_keys.remove("diff")
|
||||
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
# on_input: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["weight"]
|
||||
self.on_input = values["on_input"]
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
assert orig_weight is not None
|
||||
return orig_weight * weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||
|
||||
|
||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
_name: str
|
||||
layers: Dict[str, AnyLoRALayer]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
layers: Dict[str, AnyLoRALayer],
|
||||
):
|
||||
self._name = name
|
||||
self.layers = layers
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
# TODO: try revert if exception?
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for _, layer in self.layers.items():
|
||||
model_size += layer.calc_size()
|
||||
return model_size
|
||||
|
||||
@classmethod
|
||||
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
diffusers format, then this function will have no effect.
|
||||
|
||||
This function is adapted from:
|
||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
|
||||
|
||||
Args:
|
||||
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
|
||||
|
||||
Returns:
|
||||
Dict[str, Tensor]: The diffusers-format state_dict.
|
||||
"""
|
||||
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
|
||||
not_converted_count = 0 # The number of keys that were not converted.
|
||||
|
||||
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
|
||||
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
|
||||
# `input_blocks_4_1_proj_in`.
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
|
||||
position = bisect.bisect_right(stability_unet_keys, search_key)
|
||||
map_key = stability_unet_keys[position - 1]
|
||||
# Now, check if the map_key *actually* matches the search_key.
|
||||
if search_key.startswith(map_key):
|
||||
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
|
||||
new_state_dict[new_key] = value
|
||||
converted_count += 1
|
||||
else:
|
||||
new_state_dict[full_key] = value
|
||||
not_converted_count += 1
|
||||
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
|
||||
new_state_dict[full_key] = value
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
|
||||
|
||||
if converted_count > 0 and not_converted_count > 0:
|
||||
raise ValueError(
|
||||
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
|
||||
f" not_converted={not_converted_count}"
|
||||
)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
file_path: Union[str, Path],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
) -> Self:
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
|
||||
model = cls(
|
||||
name=file_path.stem,
|
||||
layers={},
|
||||
)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
sd = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
sd = torch.load(file_path, map_location="cpu")
|
||||
|
||||
state_dict = cls._group_state(sd)
|
||||
|
||||
if base_model == BaseModelType.StableDiffusionXL:
|
||||
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
|
||||
for layer_key, values in state_dict.items():
|
||||
# lora and locon
|
||||
if "lora_down.weight" in values:
|
||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||
|
||||
# loha
|
||||
elif "hada_w1_b" in values:
|
||||
layer = LoHALayer(layer_key, values)
|
||||
|
||||
# lokr
|
||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
||||
layer = LoKRLayer(layer_key, values)
|
||||
|
||||
# diff
|
||||
elif "diff" in values:
|
||||
layer = FullLayer(layer_key, values)
|
||||
|
||||
# ia3
|
||||
elif "weight" in values and "on_input" in values:
|
||||
layer = IA3Layer(layer_key, values)
|
||||
|
||||
else:
|
||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||
raise Exception("Unknown lora format!")
|
||||
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype)
|
||||
model.layers[layer_key] = layer
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
if stem not in state_dict_groupped:
|
||||
state_dict_groupped[stem] = {}
|
||||
state_dict_groupped[stem][leaf] = value
|
||||
|
||||
return state_dict_groupped
|
||||
|
||||
|
||||
# code from
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
||||
unet_conversion_map_layer = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
# if i > 0: commentout for sdxl
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0.", "norm1."),
|
||||
("in_layers.2.", "conv1."),
|
||||
("out_layers.0.", "norm2."),
|
||||
("out_layers.3.", "conv2."),
|
||||
("emb_layers.1.", "time_emb_proj."),
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||
else:
|
||||
unet_conversion_map.append((sd, hf))
|
||||
|
||||
for j in range(2):
|
||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||
|
||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||
|
||||
return unet_conversion_map
|
||||
|
||||
|
||||
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
||||
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
|
||||
}
|
||||
0
invokeai/backend/lora/__init__.py
Normal file
42
invokeai/backend/lora/full_layer.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["diff"]
|
||||
|
||||
if len(values.keys()) > 1:
|
||||
_keys = list(values.keys())
|
||||
_keys.remove("diff")
|
||||
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
45
invokeai/backend/lora/ia3_layer.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
# on_input: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["weight"]
|
||||
self.on_input = values["on_input"]
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
assert orig_weight is not None
|
||||
return orig_weight * weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
69
invokeai/backend/lora/loha_layer.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
# w1_a: torch.Tensor
|
||||
# w1_b: torch.Tensor
|
||||
# w2_a: torch.Tensor
|
||||
# w2_b: torch.Tensor
|
||||
# t1: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
self.w2_a = values["hada_w2_a"]
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
|
||||
if "hada_t1" in values:
|
||||
self.t1: Optional[torch.Tensor] = values["hada_t1"]
|
||||
else:
|
||||
self.t1 = None
|
||||
|
||||
if "hada_t2" in values:
|
||||
self.t2: Optional[torch.Tensor] = values["hada_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
if self.t1 is None:
|
||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
||||
weight = rebuild1 * rebuild2
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
110
invokeai/backend/lora/lokr_layer.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
# w1: Optional[torch.Tensor] = None
|
||||
# w1_a: Optional[torch.Tensor] = None
|
||||
# w1_b: Optional[torch.Tensor] = None
|
||||
# w2: Optional[torch.Tensor] = None
|
||||
# w2_a: Optional[torch.Tensor] = None
|
||||
# w2_b: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
if "lokr_w1" in values:
|
||||
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
|
||||
self.w1_a = None
|
||||
self.w1_b = None
|
||||
else:
|
||||
self.w1 = None
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
|
||||
if "lokr_w2" in values:
|
||||
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
else:
|
||||
self.w2 = None
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
|
||||
if "lokr_t2" in values:
|
||||
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
if "lokr_w1_b" in values:
|
||||
self.rank = values["lokr_w1_b"].shape[0]
|
||||
elif "lokr_w2_b" in values:
|
||||
self.rank = values["lokr_w2_b"].shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
w1: Optional[torch.Tensor] = self.w1
|
||||
if w1 is None:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
assert w1 is not None
|
||||
assert w2 is not None
|
||||
weight = torch.kron(w1, w2)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
81
invokeai/backend/lora/lora_layer.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.lora_layer_base import LoRALayerBase
|
||||
|
||||
|
||||
class LoRALayer(LoRALayerBase):
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
|
||||
self.mid: Optional[torch.Tensor] = values.get("lora_mid.weight", None)
|
||||
self.dora_scale: Optional[torch.Tensor] = values.get("dora_scale", None)
|
||||
self.rank = self.down.shape[0]
|
||||
|
||||
def _apply_dora(self, orig_weight: torch.Tensor, lora_weight: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply DoRA to the weight matrix.
|
||||
|
||||
This function is based roughly on the reference implementation in PEFT, but handles scaling in a slightly
|
||||
different way:
|
||||
https://github.com/huggingface/peft/blob/26726bf1ddee6ca75ed4e1bfd292094526707a78/src/peft/tuners/lora/layer.py#L421-L433
|
||||
|
||||
"""
|
||||
# Merge the original weight with the LoRA weight.
|
||||
merged_weight = orig_weight + lora_weight
|
||||
|
||||
# Calculate the vector-wise L2 norm of the weight matrix across each column vector.
|
||||
weight_norm: torch.Tensor = torch.linalg.norm(merged_weight, dim=1)
|
||||
|
||||
dora_factor = self.dora_scale / weight_norm
|
||||
new_weight = dora_factor * merged_weight
|
||||
|
||||
# TODO(ryand): This is wasteful. We already have the final weight, but we calculate the diff, because that is
|
||||
# what the `get_weight()` API is expected to return. If we do refactor this, we'll have to give some thought to
|
||||
# how lora weight scaling should be applied - having the full weight diff makes this easy.
|
||||
weight_diff = new_weight - orig_weight
|
||||
return weight_diff
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
|
||||
if self.dora_scale is not None:
|
||||
assert orig_weight is not None
|
||||
weight = self._apply_dora(orig_weight, weight)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.up, self.mid, self.down]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
|
||||
if self.dora_scale is not None:
|
||||
self.dora_scale = self.dora_scale.to(device=device, dtype=dtype)
|
||||
55
invokeai/backend/lora/lora_layer_base.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
# alpha: Optional[float]
|
||||
# bias: Optional[torch.Tensor]
|
||||
# layer_key: str
|
||||
|
||||
# @property
|
||||
# def scale(self):
|
||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
if "alpha" in values:
|
||||
self.alpha = values["alpha"].item()
|
||||
else:
|
||||
self.alpha = None
|
||||
|
||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
|
||||
values["bias_indices"],
|
||||
values["bias_values"],
|
||||
tuple(values["bias_size"]),
|
||||
)
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for val in [self.bias]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
111
invokeai/backend/lora/lora_model.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora.full_layer import FullLayer
|
||||
from invokeai.backend.lora.ia3_layer import IA3Layer
|
||||
from invokeai.backend.lora.loha_layer import LoHALayer
|
||||
from invokeai.backend.lora.lokr_layer import LoKRLayer
|
||||
from invokeai.backend.lora.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.sdxl_state_dict_utils import convert_sdxl_keys_to_diffusers_format
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.util.serialization import load_state_dict
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||
|
||||
|
||||
class LoRAModelRaw(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
layers: dict[str, AnyLoRALayer],
|
||||
):
|
||||
super().__init__()
|
||||
self._name = name
|
||||
self.layers = layers
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
# TODO: try revert if exception?
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for _, layer in self.layers.items():
|
||||
model_size += layer.calc_size()
|
||||
return model_size
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
file_path: Union[str, Path],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
):
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
file_path = Path(file_path)
|
||||
|
||||
model_name = file_path.stem
|
||||
|
||||
sd = load_state_dict(file_path, device=str(device))
|
||||
state_dict = cls._group_state(sd)
|
||||
|
||||
if base_model == BaseModelType.StableDiffusionXL:
|
||||
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
|
||||
layers: dict[str, AnyLoRALayer] = {}
|
||||
for layer_key, values in state_dict.items():
|
||||
# lora and locon
|
||||
if "lora_down.weight" in values:
|
||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||
|
||||
# loha
|
||||
elif "hada_w1_b" in values:
|
||||
layer = LoHALayer(layer_key, values)
|
||||
|
||||
# lokr
|
||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
||||
layer = LoKRLayer(layer_key, values)
|
||||
|
||||
# diff
|
||||
elif "diff" in values:
|
||||
layer = FullLayer(layer_key, values)
|
||||
|
||||
# ia3
|
||||
elif "weight" in values and "on_input" in values:
|
||||
layer = IA3Layer(layer_key, values)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown lora layer module in {model_name}: {layer_key}: {list(values.keys())}")
|
||||
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype)
|
||||
layers[layer_key] = layer
|
||||
|
||||
return cls(name=model_name, layers=layers)
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||
state_dict_groupped: dict[str, dict[str, torch.Tensor]] = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
if stem not in state_dict_groupped:
|
||||
state_dict_groupped[stem] = {}
|
||||
state_dict_groupped[stem][leaf] = value
|
||||
|
||||
return state_dict_groupped
|
||||
137
invokeai/backend/lora/lora_model_patcher.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from invokeai.backend.lora.lora_model import LoRAModelRaw
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
|
||||
|
||||
class LoraModelPatcher:
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
|
||||
if not lora_key.startswith(prefix):
|
||||
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
||||
|
||||
module = model
|
||||
module_key = ""
|
||||
key_parts = lora_key[len(prefix) :].split("_")
|
||||
|
||||
submodule_name = key_parts.pop(0)
|
||||
|
||||
while len(key_parts) > 0:
|
||||
try:
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key += "." + submodule_name
|
||||
submodule_name = key_parts.pop(0)
|
||||
except Exception:
|
||||
submodule_name += "_" + key_parts.pop(0)
|
||||
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
||||
|
||||
return (module_key, module)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
):
|
||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
):
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_sdxl_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
):
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_sdxl_lora_text_encoder2(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
):
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora(
|
||||
cls,
|
||||
model: AnyModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
):
|
||||
original_weights = {}
|
||||
try:
|
||||
with torch.no_grad():
|
||||
for lora, lora_weight in loras:
|
||||
# assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
||||
# should be improved in the following ways:
|
||||
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
||||
# LoRA model is applied.
|
||||
# 2. From an API perspective, there's no reason that the `LoraModelPatcher` should be aware of
|
||||
# the intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
||||
# weights to have valid keys.
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
|
||||
if module_key not in original_weights:
|
||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
||||
layer.to(device=torch.device("cpu"))
|
||||
|
||||
if module.weight.shape != layer_weight.shape:
|
||||
layer_weight = layer_weight.reshape(module.weight.shape)
|
||||
|
||||
module.weight += layer_weight.to(dtype=dtype)
|
||||
|
||||
yield # wait for context manager exit
|
||||
|
||||
finally:
|
||||
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
||||
with torch.no_grad():
|
||||
for module_key, weight in original_weights.items():
|
||||
model.get_submodule(module_key).weight.copy_(weight)
|
||||
157
invokeai/backend/lora/sdxl_state_dict_utils.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import bisect
|
||||
from typing import TypeVar
|
||||
|
||||
|
||||
def make_sdxl_unet_conversion_map() -> list[tuple[str, str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.
|
||||
|
||||
Ported from:
|
||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
"""
|
||||
unet_conversion_map_layer: list[tuple[str, str]] = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
# if i > 0: commentout for sdxl
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0.", "norm1."),
|
||||
("in_layers.2.", "conv1."),
|
||||
("out_layers.0.", "norm2."),
|
||||
("out_layers.3.", "conv2."),
|
||||
("emb_layers.1.", "time_emb_proj."),
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map: list[tuple[str, str]] = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||
else:
|
||||
unet_conversion_map.append((sd, hf))
|
||||
|
||||
for j in range(2):
|
||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||
|
||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||
|
||||
return unet_conversion_map
|
||||
|
||||
|
||||
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
||||
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
|
||||
}
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, T]) -> dict[str, T]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
diffusers format, then this function will have no effect.
|
||||
|
||||
This function is adapted from:
|
||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
|
||||
|
||||
Args:
|
||||
state_dict (dict[str, Tensor]): The SDXL LoRA state_dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: The diffusers-format state_dict.
|
||||
"""
|
||||
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
|
||||
not_converted_count = 0 # The number of keys that were not converted.
|
||||
|
||||
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
|
||||
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
|
||||
# `input_blocks_4_1_proj_in`.
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict: dict[str, T] = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
|
||||
position = bisect.bisect_right(stability_unet_keys, search_key)
|
||||
map_key = stability_unet_keys[position - 1]
|
||||
# Now, check if the map_key *actually* matches the search_key.
|
||||
if search_key.startswith(map_key):
|
||||
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
|
||||
new_state_dict[new_key] = value
|
||||
converted_count += 1
|
||||
else:
|
||||
new_state_dict[full_key] = value
|
||||
not_converted_count += 1
|
||||
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
|
||||
new_state_dict[full_key] = value
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
|
||||
|
||||
if converted_count > 0 and not_converted_count > 0:
|
||||
raise ValueError(
|
||||
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
|
||||
f" not_converted={not_converted_count}"
|
||||
)
|
||||
|
||||
return new_state_dict
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Re-export frequently-used symbols from the Model Manager backend."""
|
||||
|
||||
from .config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
@@ -18,7 +17,6 @@ from .probe import ModelProbe
|
||||
from .search import ModelSearch
|
||||
|
||||
__all__ = [
|
||||
"AnyModel",
|
||||
"AnyModelConfig",
|
||||
"BaseModelType",
|
||||
"ModelRepoVariant",
|
||||
|
||||
12
invokeai/backend/model_manager/any_model_type.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora.lora_model import LoRAModelRaw
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
AnyModel = Union[ModelMixin, torch.nn.Module, IPAdapter, LoRAModelRaw, TextualInversionModelRaw, IAIOnnxRuntimeModel]
|
||||
@@ -24,20 +24,12 @@ import time
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Type, TypeAlias, Union
|
||||
|
||||
import torch
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
from ..raw_model import RawModel
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||
|
||||
@@ -15,7 +15,7 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
)
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from . import AnyModel
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
|
||||
|
||||
def convert_ldm_vae_to_diffusers(
|
||||
|
||||
@@ -10,8 +10,8 @@ from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
@@ -7,11 +7,11 @@ from typing import Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
InvalidModelConfigException,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
|
||||
@@ -14,7 +14,8 @@ from typing import Dict, Generic, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.config import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
|
||||
|
||||
class ModelLockerBase(ABC):
|
||||
@@ -117,7 +118,7 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def stats(self) -> Optional[CacheStats]:
|
||||
def stats(self) -> CacheStats:
|
||||
"""Return collected CacheStats object."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -28,7 +28,8 @@ from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@@ -269,6 +270,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
# may raise an exception here if insufficient GPU VRAM
|
||||
self._check_free_vram(target_device, cache_entry.size)
|
||||
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
cache_entry.model.to(target_device)
|
||||
@@ -326,11 +330,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
|
||||
)
|
||||
|
||||
def make_room(self, size: int) -> None:
|
||||
def make_room(self, model_size: int) -> None:
|
||||
"""Make enough room in the cache to accommodate a new model of indicated size."""
|
||||
# calculate how much memory this model will require
|
||||
# multiplier = 2 if self.precision==torch.float32 else 1
|
||||
bytes_needed = size
|
||||
bytes_needed = model_size
|
||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||
current_size = self.cache_size()
|
||||
|
||||
@@ -385,7 +389,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
# 1 from onnx runtime object
|
||||
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
@@ -417,3 +421,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
mps.empty_cache()
|
||||
|
||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||
|
||||
def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
|
||||
if target_device.type != "cuda":
|
||||
return
|
||||
vram_device = ( # mem_get_info() needs an indexed device
|
||||
target_device if target_device.index is not None else torch.device(str(target_device), index=0)
|
||||
)
|
||||
free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
|
||||
if needed_size > free_mem:
|
||||
needed_gb = round(needed_size / GIG, 2)
|
||||
free_gb = round(free_mem / GIG, 2)
|
||||
raise torch.cuda.OutOfMemoryError(
|
||||
f"Insufficient VRAM to load model, requested {needed_gb}GB but only had {free_gb}GB free"
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ Base class and implementation of a class that moves models in and out of VRAM.
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
|
||||
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
|
||||
|
||||
@@ -34,6 +34,7 @@ class ModelLocker(ModelLockerBase):
|
||||
|
||||
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
|
||||
self._cache_entry.lock()
|
||||
|
||||
try:
|
||||
if self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||
@@ -50,7 +51,6 @@ class ModelLocker(ModelLockerBase):
|
||||
except Exception:
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
|
||||
return self.model
|
||||
|
||||
def unlock(self) -> None:
|
||||
|
||||
@@ -5,12 +5,12 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
@@ -17,6 +16,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
@@ -7,9 +7,9 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
|
||||
@@ -25,7 +25,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in an IP-Adapter model.")
|
||||
model_path = Path(config.path)
|
||||
model: RawModel = build_ip_adapter(
|
||||
model = build_ip_adapter(
|
||||
ip_adapter_ckpt_path=model_path,
|
||||
device=torch.device("cpu"),
|
||||
dtype=self._torch_dtype,
|
||||
|
||||
@@ -6,15 +6,15 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_model import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
|
||||
@@ -6,13 +6,13 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
@@ -5,7 +5,6 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
@@ -13,6 +12,7 @@ from invokeai.backend.model_manager import (
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
DiffusersConfigBase,
|
||||
|
||||
@@ -5,13 +5,13 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
@@ -14,7 +14,8 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import AnyModel, CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Optional
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ def skip_torch_weight_init() -> Generator[None, None, None]:
|
||||
completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
|
||||
monkey-patches common torch layers to skip the weight initialization step.
|
||||
"""
|
||||
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
|
||||
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding, torch.nn.LayerNorm]
|
||||
saved_functions = [hasattr(m, "reset_parameters") and m.reset_parameters for m in torch_modules]
|
||||
|
||||
try:
|
||||
|
||||
@@ -13,157 +13,14 @@ from diffusers import OnnxRuntimeModel, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.lora.lora_model import LoRAModelRaw
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
|
||||
from .lora import LoRAModelRaw
|
||||
from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||
|
||||
"""
|
||||
loras = [
|
||||
(lora_model1, 0.7),
|
||||
(lora_model2, 0.4),
|
||||
]
|
||||
with LoRAHelper.apply_lora_unet(unet, loras):
|
||||
# unet with applied loras
|
||||
# unmodified unet
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# TODO: rename smth like ModelPatcher and add TI method?
|
||||
class ModelPatcher:
|
||||
@staticmethod
|
||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||
assert "." not in lora_key
|
||||
|
||||
if not lora_key.startswith(prefix):
|
||||
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
||||
|
||||
module = model
|
||||
module_key = ""
|
||||
key_parts = lora_key[len(prefix) :].split("_")
|
||||
|
||||
submodule_name = key_parts.pop(0)
|
||||
|
||||
while len(key_parts) > 0:
|
||||
try:
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key += "." + submodule_name
|
||||
submodule_name = key_parts.pop(0)
|
||||
except Exception:
|
||||
submodule_name += "_" + key_parts.pop(0)
|
||||
|
||||
module = module.get_submodule(submodule_name)
|
||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
||||
|
||||
return (module_key, module)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_sdxl_lora_text_encoder(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: List[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_sdxl_lora_text_encoder2(
|
||||
cls,
|
||||
text_encoder: CLIPTextModel,
|
||||
loras: List[Tuple[LoRAModelRaw, float]],
|
||||
) -> None:
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_lora(
|
||||
cls,
|
||||
model: AnyModel,
|
||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||
prefix: str,
|
||||
) -> None:
|
||||
original_weights = {}
|
||||
try:
|
||||
with torch.no_grad():
|
||||
for lora, lora_weight in loras:
|
||||
# assert lora.device.type == "cpu"
|
||||
for layer_key, layer in lora.layers.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
||||
# should be improved in the following ways:
|
||||
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
||||
# LoRA model is applied.
|
||||
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
|
||||
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
||||
# weights to have valid keys.
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
|
||||
if module_key not in original_weights:
|
||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
||||
layer.to(device=torch.device("cpu"))
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
if module.weight.shape != layer_weight.shape:
|
||||
# TODO: debug on lycoris
|
||||
assert hasattr(layer_weight, "reshape")
|
||||
layer_weight = layer_weight.reshape(module.weight.shape)
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
module.weight += layer_weight.to(dtype=dtype)
|
||||
|
||||
yield # wait for context manager exit
|
||||
|
||||
finally:
|
||||
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
||||
with torch.no_grad():
|
||||
for module_key, weight in original_weights.items():
|
||||
model.get_submodule(module_key).weight.copy_(weight)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_ti(
|
||||
|
||||
@@ -6,17 +6,16 @@ from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
from onnx import numpy_helper
|
||||
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
||||
|
||||
from ..raw_model import RawModel
|
||||
|
||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
|
||||
|
||||
# NOTE FROM LS: This was copied from Stalker's original implementation.
|
||||
# I have not yet gone through and fixed all the type hints
|
||||
class IAIOnnxRuntimeModel(RawModel):
|
||||
class IAIOnnxRuntimeModel(torch.nn.Module):
|
||||
class _tensor_access:
|
||||
def __init__(self, model): # type: ignore
|
||||
self.model = model
|
||||
@@ -103,7 +102,7 @@ class IAIOnnxRuntimeModel(RawModel):
|
||||
|
||||
self.proto = onnx.load(model_path, load_external_data=False)
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.proto = onnx.load(model_path, load_external_data=True)
|
||||
# self.data = dict()
|
||||
# for tensor in self.proto.graph.initializer:
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
"""Base class for 'Raw' models.
|
||||
|
||||
The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
|
||||
and is used for type checking of calls to the model patcher. Its main purpose
|
||||
is to avoid a circular import issues when lora.py tries to import BaseModelType
|
||||
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
|
||||
from lora.py.
|
||||
|
||||
The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
|
||||
that adds additional methods and attributes.
|
||||
"""
|
||||
|
||||
|
||||
class RawModel:
|
||||
"""Base class for 'Raw' model wrappers."""
|
||||
@@ -9,10 +9,8 @@ from safetensors.torch import load_file
|
||||
from transformers import CLIPTokenizer
|
||||
from typing_extensions import Self
|
||||
|
||||
from .raw_model import RawModel
|
||||
|
||||
|
||||
class TextualInversionModelRaw(RawModel):
|
||||
class TextualInversionModelRaw(torch.nn.Module):
|
||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
|
||||
|
||||
|
||||
37
invokeai/backend/util/serialization.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
def state_dict_to(
|
||||
state_dict: dict[str, torch.Tensor], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
new_state_dict: dict[str, torch.Tensor] = {}
|
||||
for k, v in state_dict.items():
|
||||
new_state_dict[k] = v.to(device=device, dtype=dtype, non_blocking=True)
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def load_state_dict(file_path: Union[str, Path], device: str = "cpu") -> Any:
|
||||
"""Load a state_dict from a file that may be in either PyTorch or safetensors format. The file format is inferred
|
||||
from the file extension.
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
state_dict = load_file(
|
||||
file_path,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
# weights_only=True is used to address a security vulnerability that allows arbitrary code execution.
|
||||
# This option was first introduced in https://github.com/pytorch/pytorch/pull/86812.
|
||||
#
|
||||
# mmap=True is used to both reduce memory usage and speed up loading. This setting causes torch.load() to more
|
||||
# closely mirror the behaviour of safetensors.torch.load_file(). This option was first introduced in
|
||||
# https://github.com/pytorch/pytorch/pull/102549. The discussion on that PR provides helpful context.
|
||||
state_dict = torch.load(file_path, map_location=device, weights_only=True, mmap=True)
|
||||
|
||||
return state_dict
|
||||
@@ -291,6 +291,7 @@
|
||||
"canvasMerged": "تم دمج الخط",
|
||||
"sentToImageToImage": "تم إرسال إلى صورة إلى صورة",
|
||||
"sentToUnifiedCanvas": "تم إرسال إلى لوحة موحدة",
|
||||
"parametersSet": "تم تعيين المعلمات",
|
||||
"parametersNotSet": "لم يتم تعيين المعلمات",
|
||||
"metadataLoadFailed": "فشل تحميل البيانات الوصفية"
|
||||
},
|
||||
|
||||
@@ -75,8 +75,7 @@
|
||||
"copy": "Kopieren",
|
||||
"aboutHeading": "Nutzen Sie Ihre kreative Energie",
|
||||
"toResolve": "Lösen",
|
||||
"add": "Hinzufügen",
|
||||
"loglevel": "Protokoll Stufe"
|
||||
"add": "Hinzufügen"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Bildgröße",
|
||||
@@ -389,14 +388,7 @@
|
||||
"vaePrecision": "VAE-Präzision",
|
||||
"variant": "Variante",
|
||||
"modelDeleteFailed": "Modell konnte nicht gelöscht werden",
|
||||
"noModelSelected": "Kein Modell ausgewählt",
|
||||
"huggingFace": "HuggingFace",
|
||||
"defaultSettings": "Standardeinstellungen",
|
||||
"edit": "Bearbeiten",
|
||||
"cancel": "Stornieren",
|
||||
"defaultSettingsSaved": "Standardeinstellungen gespeichert",
|
||||
"addModels": "Model hinzufügen",
|
||||
"deleteModelImage": "Lösche Model Bild"
|
||||
"noModelSelected": "Kein Modell ausgewählt"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Bilder",
|
||||
@@ -480,6 +472,7 @@
|
||||
"canvasMerged": "Leinwand zusammengeführt",
|
||||
"sentToImageToImage": "Gesendet an Bild zu Bild",
|
||||
"sentToUnifiedCanvas": "Gesendet an Leinwand",
|
||||
"parametersSet": "Parameter festlegen",
|
||||
"parametersNotSet": "Parameter nicht festgelegt",
|
||||
"metadataLoadFailed": "Metadaten konnten nicht geladen werden",
|
||||
"setCanvasInitialImage": "Ausgangsbild setzen",
|
||||
@@ -684,8 +677,7 @@
|
||||
"body": "Körper",
|
||||
"hands": "Hände",
|
||||
"dwOpenpose": "DW Openpose",
|
||||
"dwOpenposeDescription": "Posenschätzung mit DW Openpose",
|
||||
"selectCLIPVisionModel": "Wähle ein CLIP Vision Model aus"
|
||||
"dwOpenposeDescription": "Posenschätzung mit DW Openpose"
|
||||
},
|
||||
"queue": {
|
||||
"status": "Status",
|
||||
@@ -773,10 +765,7 @@
|
||||
"recallParameters": "Parameter wiederherstellen",
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
||||
"allPrompts": "Alle Prompts",
|
||||
"imageDimensions": "Bilder Auslösungen",
|
||||
"parameterSet": "Parameter {{parameter}} setzen",
|
||||
"recallParameter": "{{label}} Abrufen",
|
||||
"parsingFailed": "Parsing Fehlgeschlagen"
|
||||
"imageDimensions": "Bilder Auslösungen"
|
||||
},
|
||||
"popovers": {
|
||||
"noiseUseCPU": {
|
||||
@@ -1041,8 +1030,7 @@
|
||||
"title": "Bild"
|
||||
},
|
||||
"advanced": {
|
||||
"title": "Erweitert",
|
||||
"options": "$t(accordions.advanced.title) Optionen"
|
||||
"title": "Erweitert"
|
||||
},
|
||||
"control": {
|
||||
"title": "Kontrolle"
|
||||
|
||||
@@ -684,7 +684,6 @@
|
||||
"noModelsInstalled": "No Models Installed",
|
||||
"noModelsInstalledDesc1": "Install models with the",
|
||||
"noModelSelected": "No Model Selected",
|
||||
"noMatchingModels": "No matching Models",
|
||||
"none": "none",
|
||||
"path": "Path",
|
||||
"pathToConfig": "Path To Config",
|
||||
@@ -888,11 +887,6 @@
|
||||
"imageFit": "Fit Initial Image To Output Size",
|
||||
"images": "Images",
|
||||
"infillMethod": "Infill Method",
|
||||
"infillMosaicTileWidth": "Tile Width",
|
||||
"infillMosaicTileHeight": "Tile Height",
|
||||
"infillMosaicMinColor": "Min Color",
|
||||
"infillMosaicMaxColor": "Max Color",
|
||||
"infillColorValue": "Fill Color",
|
||||
"info": "Info",
|
||||
"invoke": {
|
||||
"addingImagesTo": "Adding images to",
|
||||
@@ -1041,10 +1035,10 @@
|
||||
"metadataLoadFailed": "Failed to load metadata",
|
||||
"modelAddedSimple": "Model Added to Queue",
|
||||
"modelImportCanceled": "Model Import Canceled",
|
||||
"parameters": "Parameters",
|
||||
"parameterNotSet": "{{parameter}} not set",
|
||||
"parameterSet": "{{parameter}} set",
|
||||
"parametersNotSet": "Parameters Not Set",
|
||||
"parametersSet": "Parameters Set",
|
||||
"problemCopyingCanvas": "Problem Copying Canvas",
|
||||
"problemCopyingCanvasDesc": "Unable to export base layer",
|
||||
"problemCopyingImage": "Unable to Copy Image",
|
||||
@@ -1423,7 +1417,6 @@
|
||||
"eraseBoundingBox": "Erase Bounding Box",
|
||||
"eraser": "Eraser",
|
||||
"fillBoundingBox": "Fill Bounding Box",
|
||||
"initialFitImageSize": "Fit Image Size on Drop",
|
||||
"invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
|
||||
"layer": "Layer",
|
||||
"limitStrokesToBox": "Limit Strokes to Box",
|
||||
|
||||
@@ -363,6 +363,7 @@
|
||||
"canvasMerged": "Lienzo consolidado",
|
||||
"sentToImageToImage": "Enviar hacia Imagen a Imagen",
|
||||
"sentToUnifiedCanvas": "Enviar hacia Lienzo Consolidado",
|
||||
"parametersSet": "Parámetros establecidos",
|
||||
"parametersNotSet": "Parámetros no establecidos",
|
||||
"metadataLoadFailed": "Error al cargar metadatos",
|
||||
"serverError": "Error en el servidor",
|
||||
|
||||
@@ -298,6 +298,7 @@
|
||||
"canvasMerged": "Canvas fusionné",
|
||||
"sentToImageToImage": "Envoyé à Image à Image",
|
||||
"sentToUnifiedCanvas": "Envoyé à Canvas unifié",
|
||||
"parametersSet": "Paramètres définis",
|
||||
"parametersNotSet": "Paramètres non définis",
|
||||
"metadataLoadFailed": "Échec du chargement des métadonnées"
|
||||
},
|
||||
|
||||
@@ -306,6 +306,7 @@
|
||||
"canvasMerged": "קנבס מוזג",
|
||||
"sentToImageToImage": "נשלח לתמונה לתמונה",
|
||||
"sentToUnifiedCanvas": "נשלח אל קנבס מאוחד",
|
||||
"parametersSet": "הגדרת פרמטרים",
|
||||
"parametersNotSet": "פרמטרים לא הוגדרו",
|
||||
"metadataLoadFailed": "טעינת מטא-נתונים נכשלה"
|
||||
},
|
||||
|
||||
@@ -366,7 +366,7 @@
|
||||
"modelConverted": "Modello convertito",
|
||||
"alpha": "Alpha",
|
||||
"convertToDiffusersHelpText1": "Questo modello verrà convertito nel formato 🧨 Diffusori.",
|
||||
"convertToDiffusersHelpText3": "Il file del modello su disco verrà eliminato se si trova nella cartella principale di InvokeAI. Se si trova invece in una posizione personalizzata, NON verrà eliminato.",
|
||||
"convertToDiffusersHelpText3": "Il file Checkpoint su disco verrà eliminato se si trova nella cartella principale di InvokeAI. Se si trova invece in una posizione personalizzata, NON verrà eliminato.",
|
||||
"v2_base": "v2 (512px)",
|
||||
"v2_768": "v2 (768px)",
|
||||
"none": "nessuno",
|
||||
@@ -443,8 +443,7 @@
|
||||
"noModelsInstalled": "Nessun modello installato",
|
||||
"hfTokenInvalidErrorMessage2": "Aggiornalo in ",
|
||||
"main": "Principali",
|
||||
"noModelsInstalledDesc1": "Installa i modelli con",
|
||||
"ipAdapters": "Adattatori IP"
|
||||
"noModelsInstalledDesc1": "Installa i modelli con"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Immagini",
|
||||
@@ -569,6 +568,7 @@
|
||||
"canvasMerged": "Tela unita",
|
||||
"sentToImageToImage": "Inviato a Immagine a Immagine",
|
||||
"sentToUnifiedCanvas": "Inviato a Tela Unificata",
|
||||
"parametersSet": "Parametri impostati",
|
||||
"parametersNotSet": "Parametri non impostati",
|
||||
"metadataLoadFailed": "Impossibile caricare i metadati",
|
||||
"serverError": "Errore del Server",
|
||||
@@ -937,8 +937,7 @@
|
||||
"controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
|
||||
"mediapipeFace": "Mediapipe Volto",
|
||||
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
|
||||
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))",
|
||||
"selectCLIPVisionModel": "Seleziona un modello CLIP Vision"
|
||||
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))"
|
||||
},
|
||||
"queue": {
|
||||
"queueFront": "Aggiungi all'inizio della coda",
|
||||
|
||||
@@ -420,6 +420,7 @@
|
||||
"canvasMerged": "Canvas samengevoegd",
|
||||
"sentToImageToImage": "Gestuurd naar Afbeelding naar afbeelding",
|
||||
"sentToUnifiedCanvas": "Gestuurd naar Centraal canvas",
|
||||
"parametersSet": "Parameters ingesteld",
|
||||
"parametersNotSet": "Parameters niet ingesteld",
|
||||
"metadataLoadFailed": "Fout bij laden metagegevens",
|
||||
"serverError": "Serverfout",
|
||||
|
||||
@@ -267,6 +267,7 @@
|
||||
"canvasMerged": "Scalono widoczne warstwy",
|
||||
"sentToImageToImage": "Wysłano do Obraz na obraz",
|
||||
"sentToUnifiedCanvas": "Wysłano do trybu uniwersalnego",
|
||||
"parametersSet": "Ustawiono parametry",
|
||||
"parametersNotSet": "Nie ustawiono parametrów",
|
||||
"metadataLoadFailed": "Błąd wczytywania metadanych"
|
||||
},
|
||||
|
||||
@@ -310,6 +310,7 @@
|
||||
"canvasMerged": "Tela Fundida",
|
||||
"sentToImageToImage": "Mandar Para Imagem Para Imagem",
|
||||
"sentToUnifiedCanvas": "Enviada para a Tela Unificada",
|
||||
"parametersSet": "Parâmetros Definidos",
|
||||
"parametersNotSet": "Parâmetros Não Definidos",
|
||||
"metadataLoadFailed": "Falha ao tentar carregar metadados"
|
||||
},
|
||||
|
||||
@@ -307,6 +307,7 @@
|
||||
"canvasMerged": "Tela Fundida",
|
||||
"sentToImageToImage": "Mandar Para Imagem Para Imagem",
|
||||
"sentToUnifiedCanvas": "Enviada para a Tela Unificada",
|
||||
"parametersSet": "Parâmetros Definidos",
|
||||
"parametersNotSet": "Parâmetros Não Definidos",
|
||||
"metadataLoadFailed": "Falha ao tentar carregar metadados"
|
||||
},
|
||||
|
||||
@@ -575,6 +575,7 @@
|
||||
"canvasMerged": "Холст объединен",
|
||||
"sentToImageToImage": "Отправить в img2img",
|
||||
"sentToUnifiedCanvas": "Отправлено на Единый холст",
|
||||
"parametersSet": "Параметры заданы",
|
||||
"parametersNotSet": "Параметры не заданы",
|
||||
"metadataLoadFailed": "Не удалось загрузить метаданные",
|
||||
"serverError": "Ошибка сервера",
|
||||
|
||||
@@ -315,6 +315,7 @@
|
||||
"canvasMerged": "Полотно об'єднане",
|
||||
"sentToImageToImage": "Надіслати до img2img",
|
||||
"sentToUnifiedCanvas": "Надіслати на полотно",
|
||||
"parametersSet": "Параметри задані",
|
||||
"parametersNotSet": "Параметри не задані",
|
||||
"metadataLoadFailed": "Не вдалося завантажити метадані",
|
||||
"serverError": "Помилка сервера",
|
||||
|
||||
@@ -487,6 +487,7 @@
|
||||
"canvasMerged": "画布已合并",
|
||||
"sentToImageToImage": "已发送到图生图",
|
||||
"sentToUnifiedCanvas": "已发送到统一画布",
|
||||
"parametersSet": "参数已设定",
|
||||
"parametersNotSet": "参数未设定",
|
||||
"metadataLoadFailed": "加载元数据失败",
|
||||
"uploadFailedInvalidUploadDesc": "必须是单张的 PNG 或 JPEG 图片",
|
||||
|
||||
@@ -18,7 +18,6 @@ import {
|
||||
setShouldAutoSave,
|
||||
setShouldCropToBoundingBoxOnSave,
|
||||
setShouldDarkenOutsideBoundingBox,
|
||||
setShouldFitImageSize,
|
||||
setShouldInvertBrushSizeScrollDirection,
|
||||
setShouldRestrictStrokesToBox,
|
||||
setShouldShowCanvasDebugInfo,
|
||||
@@ -49,7 +48,6 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
const shouldSnapToGrid = useAppSelector((s) => s.canvas.shouldSnapToGrid);
|
||||
const shouldRestrictStrokesToBox = useAppSelector((s) => s.canvas.shouldRestrictStrokesToBox);
|
||||
const shouldAntialias = useAppSelector((s) => s.canvas.shouldAntialias);
|
||||
const shouldFitImageSize = useAppSelector((s) => s.canvas.shouldFitImageSize);
|
||||
|
||||
useHotkeys(
|
||||
['n'],
|
||||
@@ -104,10 +102,6 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldAntialias(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
const handleChangeShouldFitImageSize = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldFitImageSize(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
@@ -171,10 +165,6 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
<FormLabel>{t('unifiedCanvas.antialiasing')}</FormLabel>
|
||||
<Checkbox isChecked={shouldAntialias} onChange={handleChangeShouldAntialias} />
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('unifiedCanvas.initialFitImageSize')}</FormLabel>
|
||||
<Checkbox isChecked={shouldFitImageSize} onChange={handleChangeShouldFitImageSize} />
|
||||
</FormControl>
|
||||
</FormControlGroup>
|
||||
<ClearCanvasHistoryButtonModal />
|
||||
</Flex>
|
||||
|
||||
@@ -66,7 +66,6 @@ const initialCanvasState: CanvasState = {
|
||||
shouldAutoSave: false,
|
||||
shouldCropToBoundingBoxOnSave: false,
|
||||
shouldDarkenOutsideBoundingBox: false,
|
||||
shouldFitImageSize: true,
|
||||
shouldInvertBrushSizeScrollDirection: false,
|
||||
shouldLockBoundingBox: false,
|
||||
shouldPreserveMaskedArea: false,
|
||||
@@ -145,20 +144,12 @@ export const canvasSlice = createSlice({
|
||||
reducer: (state, action: PayloadActionWithOptimalDimension<ImageDTO>) => {
|
||||
const { width, height, image_name } = action.payload;
|
||||
const { optimalDimension } = action.meta;
|
||||
const { stageDimensions, shouldFitImageSize } = state;
|
||||
const { stageDimensions } = state;
|
||||
|
||||
const newBoundingBoxDimensions = shouldFitImageSize
|
||||
? {
|
||||
width: roundDownToMultiple(width, CANVAS_GRID_SIZE_FINE),
|
||||
height: roundDownToMultiple(height, CANVAS_GRID_SIZE_FINE),
|
||||
}
|
||||
: {
|
||||
width: roundDownToMultiple(clamp(width, CANVAS_GRID_SIZE_FINE, optimalDimension), CANVAS_GRID_SIZE_FINE),
|
||||
height: roundDownToMultiple(
|
||||
clamp(height, CANVAS_GRID_SIZE_FINE, optimalDimension),
|
||||
CANVAS_GRID_SIZE_FINE
|
||||
),
|
||||
};
|
||||
const newBoundingBoxDimensions = {
|
||||
width: roundDownToMultiple(clamp(width, CANVAS_GRID_SIZE_FINE, optimalDimension), CANVAS_GRID_SIZE_FINE),
|
||||
height: roundDownToMultiple(clamp(height, CANVAS_GRID_SIZE_FINE, optimalDimension), CANVAS_GRID_SIZE_FINE),
|
||||
};
|
||||
|
||||
const newBoundingBoxCoordinates = {
|
||||
x: roundToMultiple(width / 2 - newBoundingBoxDimensions.width / 2, CANVAS_GRID_SIZE_FINE),
|
||||
@@ -591,9 +582,6 @@ export const canvasSlice = createSlice({
|
||||
setShouldAntialias: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldAntialias = action.payload;
|
||||
},
|
||||
setShouldFitImageSize: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldFitImageSize = action.payload;
|
||||
},
|
||||
setShouldCropToBoundingBoxOnSave: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldCropToBoundingBoxOnSave = action.payload;
|
||||
},
|
||||
@@ -704,7 +692,6 @@ export const {
|
||||
setShouldRestrictStrokesToBox,
|
||||
stagingAreaInitialized,
|
||||
setShouldAntialias,
|
||||
setShouldFitImageSize,
|
||||
canvasResized,
|
||||
canvasBatchIdAdded,
|
||||
canvasBatchIdsReset,
|
||||
|
||||
@@ -120,7 +120,6 @@ export interface CanvasState {
|
||||
shouldAutoSave: boolean;
|
||||
shouldCropToBoundingBoxOnSave: boolean;
|
||||
shouldDarkenOutsideBoundingBox: boolean;
|
||||
shouldFitImageSize: boolean;
|
||||
shouldInvertBrushSizeScrollDirection: boolean;
|
||||
shouldLockBoundingBox: boolean;
|
||||
shouldPreserveMaskedArea: boolean;
|
||||
|
||||
@@ -33,7 +33,6 @@ const ImageMetadataActions = (props: Props) => {
|
||||
<MetadataItem metadata={metadata} handlers={handlers.scheduler} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.cfgScale} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.cfgRescaleMultiplier} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.initialImage} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.strength} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.hrfEnabled} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.hrfMethod} />
|
||||
|
||||
@@ -189,12 +189,6 @@ export const handlers = {
|
||||
recaller: recallers.cfgScale,
|
||||
}),
|
||||
height: buildHandlers({ getLabel: () => t('metadata.height'), parser: parsers.height, recaller: recallers.height }),
|
||||
initialImage: buildHandlers({
|
||||
getLabel: () => t('metadata.initImage'),
|
||||
parser: parsers.initialImage,
|
||||
recaller: recallers.initialImage,
|
||||
renderValue: async (imageDTO) => imageDTO.image_name,
|
||||
}),
|
||||
negativePrompt: buildHandlers({
|
||||
getLabel: () => t('metadata.negativePrompt'),
|
||||
parser: parsers.negativePrompt,
|
||||
@@ -411,6 +405,6 @@ export const parseAndRecallAllMetadata = async (metadata: unknown, skip: (keyof
|
||||
})
|
||||
);
|
||||
if (results.some((result) => result.status === 'fulfilled')) {
|
||||
parameterSetToast(t('toast.parameters'));
|
||||
parameterSetToast(t('toast.parametersSet'));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import {
|
||||
initialControlNet,
|
||||
initialIPAdapter,
|
||||
@@ -58,8 +57,6 @@ import {
|
||||
isParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { get, isArray, isString } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import {
|
||||
isControlNetModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
@@ -138,14 +135,6 @@ const parseCFGRescaleMultiplier: MetadataParseFunc<ParameterCFGRescaleMultiplier
|
||||
const parseScheduler: MetadataParseFunc<ParameterScheduler> = (metadata) =>
|
||||
getProperty(metadata, 'scheduler', isParameterScheduler);
|
||||
|
||||
const parseInitialImage: MetadataParseFunc<ImageDTO> = async (metadata) => {
|
||||
const imageName = await getProperty(metadata, 'init_image', isString);
|
||||
const imageDTORequest = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(imageName));
|
||||
const imageDTO = await imageDTORequest.unwrap();
|
||||
imageDTORequest.unsubscribe();
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
const parseWidth: MetadataParseFunc<ParameterWidth> = (metadata) => getProperty(metadata, 'width', isParameterWidth);
|
||||
|
||||
const parseHeight: MetadataParseFunc<ParameterHeight> = (metadata) =>
|
||||
@@ -413,7 +402,6 @@ export const parsers = {
|
||||
cfgScale: parseCFGScale,
|
||||
cfgRescaleMultiplier: parseCFGRescaleMultiplier,
|
||||
scheduler: parseScheduler,
|
||||
initialImage: parseInitialImage,
|
||||
width: parseWidth,
|
||||
height: parseHeight,
|
||||
steps: parseSteps,
|
||||
|
||||
@@ -17,7 +17,6 @@ import type {
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import {
|
||||
heightRecalled,
|
||||
initialImageChanged,
|
||||
setCfgRescaleMultiplier,
|
||||
setCfgScale,
|
||||
setImg2imgStrength,
|
||||
@@ -62,7 +61,6 @@ import {
|
||||
setRefinerStart,
|
||||
setRefinerSteps,
|
||||
} from 'features/sdxl/store/sdxlSlice';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
|
||||
getStore().dispatch(setPositivePrompt(positivePrompt));
|
||||
@@ -96,10 +94,6 @@ const recallScheduler: MetadataRecallFunc<ParameterScheduler> = (scheduler) => {
|
||||
getStore().dispatch(setScheduler(scheduler));
|
||||
};
|
||||
|
||||
const recallInitialImage: MetadataRecallFunc<ImageDTO> = async (imageDTO) => {
|
||||
getStore().dispatch(initialImageChanged(imageDTO));
|
||||
};
|
||||
|
||||
const recallWidth: MetadataRecallFunc<ParameterWidth> = (width) => {
|
||||
getStore().dispatch(widthRecalled(width));
|
||||
};
|
||||
@@ -241,7 +235,6 @@ export const recallers = {
|
||||
cfgScale: recallCFGScale,
|
||||
cfgRescaleMultiplier: recallCFGRescaleMultiplier,
|
||||
scheduler: recallScheduler,
|
||||
initialImage: recallInitialImage,
|
||||
width: recallWidth,
|
||||
height: recallHeight,
|
||||
steps: recallSteps,
|
||||
|
||||
@@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig } from 'app/store/store';
|
||||
import type { ModelType } from 'services/api/types';
|
||||
|
||||
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
|
||||
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>;
|
||||
|
||||
type ModelManagerState = {
|
||||
_version: 1;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
@@ -10,11 +9,10 @@ import {
|
||||
useIPAdapterModels,
|
||||
useLoRAModels,
|
||||
useMainModels,
|
||||
useRefinerModels,
|
||||
useT2IAdapterModels,
|
||||
useVAEModels,
|
||||
} from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import type { AnyModelConfig, ModelType } from 'services/api/types';
|
||||
|
||||
import { FetchingModelsLoader } from './FetchingModelsLoader';
|
||||
import { ModelListWrapper } from './ModelListWrapper';
|
||||
@@ -29,12 +27,6 @@ const ModelList = () => {
|
||||
[mainModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [refinerModels, { isLoading: isLoadingRefinerModels }] = useRefinerModels();
|
||||
const filteredRefinerModels = useMemo(
|
||||
() => modelsFilter(refinerModels, searchTerm, filteredModelType),
|
||||
[refinerModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
|
||||
const filteredLoRAModels = useMemo(
|
||||
() => modelsFilter(loraModels, searchTerm, filteredModelType),
|
||||
@@ -71,28 +63,6 @@ const ModelList = () => {
|
||||
[vaeModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const totalFilteredModels = useMemo(() => {
|
||||
return (
|
||||
filteredMainModels.length +
|
||||
filteredRefinerModels.length +
|
||||
filteredLoRAModels.length +
|
||||
filteredEmbeddingModels.length +
|
||||
filteredControlNetModels.length +
|
||||
filteredT2IAdapterModels.length +
|
||||
filteredIPAdapterModels.length +
|
||||
filteredVAEModels.length
|
||||
);
|
||||
}, [
|
||||
filteredControlNetModels.length,
|
||||
filteredEmbeddingModels.length,
|
||||
filteredIPAdapterModels.length,
|
||||
filteredLoRAModels.length,
|
||||
filteredMainModels.length,
|
||||
filteredRefinerModels.length,
|
||||
filteredT2IAdapterModels.length,
|
||||
filteredVAEModels.length,
|
||||
]);
|
||||
|
||||
return (
|
||||
<ScrollableContent>
|
||||
<Flex flexDirection="column" w="full" h="full" gap={4}>
|
||||
@@ -101,11 +71,6 @@ const ModelList = () => {
|
||||
{!isLoadingMainModels && filteredMainModels.length > 0 && (
|
||||
<ModelListWrapper title={t('modelManager.main')} modelList={filteredMainModels} key="main" />
|
||||
)}
|
||||
{/* Refiner Model List */}
|
||||
{isLoadingRefinerModels && <FetchingModelsLoader loadingMessage="Loading Refiner Models..." />}
|
||||
{!isLoadingRefinerModels && filteredRefinerModels.length > 0 && (
|
||||
<ModelListWrapper title={t('sdxl.refiner')} modelList={filteredRefinerModels} key="refiner" />
|
||||
)}
|
||||
{/* LoRAs List */}
|
||||
{isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
||||
{!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
|
||||
@@ -143,11 +108,6 @@ const ModelList = () => {
|
||||
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
|
||||
<ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
|
||||
)}
|
||||
{totalFilteredModels === 0 && (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<Text>{t('modelManager.noMatchingModels')}</Text>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
);
|
||||
@@ -158,24 +118,12 @@ export default memo(ModelList);
|
||||
const modelsFilter = <T extends AnyModelConfig>(
|
||||
data: T[],
|
||||
nameFilter: string,
|
||||
filteredModelType: FilterableModelType | null
|
||||
filteredModelType: ModelType | null
|
||||
): T[] => {
|
||||
return data.filter((model) => {
|
||||
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
||||
const matchesType = getMatchesType(model, filteredModelType);
|
||||
const matchesType = filteredModelType ? model.type === filteredModelType : true;
|
||||
|
||||
return matchesFilter && matchesType;
|
||||
});
|
||||
};
|
||||
|
||||
const getMatchesType = (modelConfig: AnyModelConfig, filteredModelType: FilterableModelType | null): boolean => {
|
||||
if (filteredModelType === 'refiner') {
|
||||
return modelConfig.base === 'sdxl-refiner';
|
||||
}
|
||||
|
||||
if (filteredModelType === 'main' && modelConfig.base === 'sdxl-refiner') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return filteredModelType ? modelConfig.type === filteredModelType : true;
|
||||
};
|
||||
|
||||
@@ -13,7 +13,6 @@ export const ModelTypeFilter = () => {
|
||||
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
|
||||
() => ({
|
||||
main: t('modelManager.main'),
|
||||
refiner: t('sdxl.refiner'),
|
||||
lora: 'LoRA',
|
||||
embedding: t('modelManager.textualInversions'),
|
||||
controlnet: 'ControlNet',
|
||||
|
||||
@@ -65,11 +65,6 @@ export const buildCanvasOutpaintGraph = async (
|
||||
infillTileSize,
|
||||
infillPatchmatchDownscaleSize,
|
||||
infillMethod,
|
||||
// infillMosaicTileWidth,
|
||||
// infillMosaicTileHeight,
|
||||
// infillMosaicMinColor,
|
||||
// infillMosaicMaxColor,
|
||||
infillColorValue,
|
||||
clipSkip,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
@@ -361,28 +356,6 @@ export const buildCanvasOutpaintGraph = async (
|
||||
};
|
||||
}
|
||||
|
||||
// TODO: add mosaic back
|
||||
// if (infillMethod === 'mosaic') {
|
||||
// graph.nodes[INPAINT_INFILL] = {
|
||||
// type: 'infill_mosaic',
|
||||
// id: INPAINT_INFILL,
|
||||
// is_intermediate,
|
||||
// tile_width: infillMosaicTileWidth,
|
||||
// tile_height: infillMosaicTileHeight,
|
||||
// min_color: infillMosaicMinColor,
|
||||
// max_color: infillMosaicMaxColor,
|
||||
// };
|
||||
// }
|
||||
|
||||
if (infillMethod === 'color') {
|
||||
graph.nodes[INPAINT_INFILL] = {
|
||||
type: 'infill_rgba',
|
||||
id: INPAINT_INFILL,
|
||||
color: infillColorValue,
|
||||
is_intermediate,
|
||||
};
|
||||
}
|
||||
|
||||
// Handle Scale Before Processing
|
||||
if (isUsingScaledDimensions) {
|
||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||
|
||||
@@ -66,11 +66,6 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
||||
infillTileSize,
|
||||
infillPatchmatchDownscaleSize,
|
||||
infillMethod,
|
||||
// infillMosaicTileWidth,
|
||||
// infillMosaicTileHeight,
|
||||
// infillMosaicMinColor,
|
||||
// infillMosaicMaxColor,
|
||||
infillColorValue,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
canvasCoherenceMode,
|
||||
@@ -370,28 +365,6 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
||||
};
|
||||
}
|
||||
|
||||
// TODO: add mosaic back
|
||||
// if (infillMethod === 'mosaic') {
|
||||
// graph.nodes[INPAINT_INFILL] = {
|
||||
// type: 'infill_mosaic',
|
||||
// id: INPAINT_INFILL,
|
||||
// is_intermediate,
|
||||
// tile_width: infillMosaicTileWidth,
|
||||
// tile_height: infillMosaicTileHeight,
|
||||
// min_color: infillMosaicMinColor,
|
||||
// max_color: infillMosaicMaxColor,
|
||||
// };
|
||||
// }
|
||||
|
||||
if (infillMethod === 'color') {
|
||||
graph.nodes[INPAINT_INFILL] = {
|
||||
type: 'infill_rgba',
|
||||
id: INPAINT_INFILL,
|
||||
is_intermediate,
|
||||
color: infillColorValue,
|
||||
};
|
||||
}
|
||||
|
||||
// Handle Scale Before Processing
|
||||
if (isUsingScaledDimensions) {
|
||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIColorPicker from 'common/components/IAIColorPicker';
|
||||
import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import type { RgbaColor } from 'react-colorful';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const ParamInfillColorOptions = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectGenerationSlice, (generation) => ({
|
||||
infillColor: generation.infillColorValue,
|
||||
})),
|
||||
[]
|
||||
);
|
||||
|
||||
const { infillColor } = useAppSelector(selector);
|
||||
|
||||
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleInfillColor = useCallback(
|
||||
(v: RgbaColor) => {
|
||||
dispatch(setInfillColorValue(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<FormControl isDisabled={infillMethod !== 'color'}>
|
||||
<FormLabel>{t('parameters.infillColorValue')}</FormLabel>
|
||||
<Box w="full" pt={2} pb={2}>
|
||||
<IAIColorPicker color={infillColor} onChange={handleInfillColor} />
|
||||
</Box>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamInfillColorOptions);
|
||||
@@ -1,127 +0,0 @@
|
||||
import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIColorPicker from 'common/components/IAIColorPicker';
|
||||
import {
|
||||
selectGenerationSlice,
|
||||
setInfillMosaicMaxColor,
|
||||
setInfillMosaicMinColor,
|
||||
setInfillMosaicTileHeight,
|
||||
setInfillMosaicTileWidth,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import type { RgbaColor } from 'react-colorful';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const ParamInfillMosaicTileSize = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectGenerationSlice, (generation) => ({
|
||||
infillMosaicTileWidth: generation.infillMosaicTileWidth,
|
||||
infillMosaicTileHeight: generation.infillMosaicTileHeight,
|
||||
infillMosaicMinColor: generation.infillMosaicMinColor,
|
||||
infillMosaicMaxColor: generation.infillMosaicMaxColor,
|
||||
})),
|
||||
[]
|
||||
);
|
||||
|
||||
const { infillMosaicTileWidth, infillMosaicTileHeight, infillMosaicMinColor, infillMosaicMaxColor } =
|
||||
useAppSelector(selector);
|
||||
|
||||
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleInfillMosaicTileWidthChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(setInfillMosaicTileWidth(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleInfillMosaicTileHeightChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(setInfillMosaicTileHeight(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleInfillMosaicMinColor = useCallback(
|
||||
(v: RgbaColor) => {
|
||||
dispatch(setInfillMosaicMinColor(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleInfillMosaicMaxColor = useCallback(
|
||||
(v: RgbaColor) => {
|
||||
dispatch(setInfillMosaicMaxColor(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<FormControl isDisabled={infillMethod !== 'mosaic'}>
|
||||
<FormLabel>{t('parameters.infillMosaicTileWidth')}</FormLabel>
|
||||
<CompositeSlider
|
||||
min={8}
|
||||
max={256}
|
||||
value={infillMosaicTileWidth}
|
||||
defaultValue={64}
|
||||
onChange={handleInfillMosaicTileWidthChange}
|
||||
step={8}
|
||||
fineStep={8}
|
||||
marks
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
min={8}
|
||||
max={1024}
|
||||
value={infillMosaicTileWidth}
|
||||
defaultValue={64}
|
||||
onChange={handleInfillMosaicTileWidthChange}
|
||||
step={8}
|
||||
fineStep={8}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl isDisabled={infillMethod !== 'mosaic'}>
|
||||
<FormLabel>{t('parameters.infillMosaicTileHeight')}</FormLabel>
|
||||
<CompositeSlider
|
||||
min={8}
|
||||
max={256}
|
||||
value={infillMosaicTileHeight}
|
||||
defaultValue={64}
|
||||
onChange={handleInfillMosaicTileHeightChange}
|
||||
step={8}
|
||||
fineStep={8}
|
||||
marks
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
min={8}
|
||||
max={1024}
|
||||
value={infillMosaicTileHeight}
|
||||
defaultValue={64}
|
||||
onChange={handleInfillMosaicTileHeightChange}
|
||||
step={8}
|
||||
fineStep={8}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl isDisabled={infillMethod !== 'mosaic'}>
|
||||
<FormLabel>{t('parameters.infillMosaicMinColor')}</FormLabel>
|
||||
<Box w="full" pt={2} pb={2}>
|
||||
<IAIColorPicker color={infillMosaicMinColor} onChange={handleInfillMosaicMinColor} />
|
||||
</Box>
|
||||
</FormControl>
|
||||
<FormControl isDisabled={infillMethod !== 'mosaic'}>
|
||||
<FormLabel>{t('parameters.infillMosaicMaxColor')}</FormLabel>
|
||||
<Box w="full" pt={2} pb={2}>
|
||||
<IAIColorPicker color={infillMosaicMaxColor} onChange={handleInfillMosaicMaxColor} />
|
||||
</Box>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamInfillMosaicTileSize);
|
||||
@@ -1,8 +1,6 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo } from 'react';
|
||||
|
||||
import ParamInfillColorOptions from './ParamInfillColorOptions';
|
||||
import ParamInfillMosaicOptions from './ParamInfillMosaicOptions';
|
||||
import ParamInfillPatchmatchDownscaleSize from './ParamInfillPatchmatchDownscaleSize';
|
||||
import ParamInfillTilesize from './ParamInfillTilesize';
|
||||
|
||||
@@ -16,14 +14,6 @@ const ParamInfillOptions = () => {
|
||||
return <ParamInfillPatchmatchDownscaleSize />;
|
||||
}
|
||||
|
||||
if (infillMethod === 'mosaic') {
|
||||
return <ParamInfillMosaicOptions />;
|
||||
}
|
||||
|
||||
if (infillMethod === 'color') {
|
||||
return <ParamInfillColorOptions />;
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import type {
|
||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { clamp } from 'lodash-es';
|
||||
import type { RgbaColor } from 'react-colorful';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import type { GenerationState } from './types';
|
||||
@@ -44,6 +43,8 @@ const initialGenerationState: GenerationState = {
|
||||
shouldFitToWidthHeight: true,
|
||||
shouldRandomizeSeed: true,
|
||||
steps: 50,
|
||||
infillTileSize: 32,
|
||||
infillPatchmatchDownscaleSize: 1,
|
||||
width: 512,
|
||||
model: null,
|
||||
vae: null,
|
||||
@@ -54,13 +55,6 @@ const initialGenerationState: GenerationState = {
|
||||
shouldUseCpuNoise: true,
|
||||
shouldShowAdvancedOptions: false,
|
||||
aspectRatio: { ...initialAspectRatioState },
|
||||
infillTileSize: 32,
|
||||
infillPatchmatchDownscaleSize: 1,
|
||||
infillMosaicTileWidth: 64,
|
||||
infillMosaicTileHeight: 64,
|
||||
infillMosaicMinColor: { r: 0, g: 0, b: 0, a: 1 },
|
||||
infillMosaicMaxColor: { r: 255, g: 255, b: 255, a: 1 },
|
||||
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
|
||||
};
|
||||
|
||||
export const generationSlice = createSlice({
|
||||
@@ -122,6 +116,15 @@ export const generationSlice = createSlice({
|
||||
setCanvasCoherenceMinDenoise: (state, action: PayloadAction<number>) => {
|
||||
state.canvasCoherenceMinDenoise = action.payload;
|
||||
},
|
||||
setInfillMethod: (state, action: PayloadAction<string>) => {
|
||||
state.infillMethod = action.payload;
|
||||
},
|
||||
setInfillTileSize: (state, action: PayloadAction<number>) => {
|
||||
state.infillTileSize = action.payload;
|
||||
},
|
||||
setInfillPatchmatchDownscaleSize: (state, action: PayloadAction<number>) => {
|
||||
state.infillPatchmatchDownscaleSize = action.payload;
|
||||
},
|
||||
initialImageChanged: (state, action: PayloadAction<ImageDTO>) => {
|
||||
const { image_name, width, height } = action.payload;
|
||||
state.initialImage = { imageName: image_name, width, height };
|
||||
@@ -203,30 +206,6 @@ export const generationSlice = createSlice({
|
||||
aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
|
||||
state.aspectRatio = action.payload;
|
||||
},
|
||||
setInfillMethod: (state, action: PayloadAction<string>) => {
|
||||
state.infillMethod = action.payload;
|
||||
},
|
||||
setInfillTileSize: (state, action: PayloadAction<number>) => {
|
||||
state.infillTileSize = action.payload;
|
||||
},
|
||||
setInfillPatchmatchDownscaleSize: (state, action: PayloadAction<number>) => {
|
||||
state.infillPatchmatchDownscaleSize = action.payload;
|
||||
},
|
||||
setInfillMosaicTileWidth: (state, action: PayloadAction<number>) => {
|
||||
state.infillMosaicTileWidth = action.payload;
|
||||
},
|
||||
setInfillMosaicTileHeight: (state, action: PayloadAction<number>) => {
|
||||
state.infillMosaicTileHeight = action.payload;
|
||||
},
|
||||
setInfillMosaicMinColor: (state, action: PayloadAction<RgbaColor>) => {
|
||||
state.infillMosaicMinColor = action.payload;
|
||||
},
|
||||
setInfillMosaicMaxColor: (state, action: PayloadAction<RgbaColor>) => {
|
||||
state.infillMosaicMaxColor = action.payload;
|
||||
},
|
||||
setInfillColorValue: (state, action: PayloadAction<RgbaColor>) => {
|
||||
state.infillColorValue = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(configChanged, (state, action) => {
|
||||
@@ -270,6 +249,8 @@ export const {
|
||||
setShouldFitToWidthHeight,
|
||||
setShouldRandomizeSeed,
|
||||
setSteps,
|
||||
setInfillTileSize,
|
||||
setInfillPatchmatchDownscaleSize,
|
||||
initialImageChanged,
|
||||
modelChanged,
|
||||
vaeSelected,
|
||||
@@ -283,13 +264,6 @@ export const {
|
||||
heightChanged,
|
||||
widthRecalled,
|
||||
heightRecalled,
|
||||
setInfillTileSize,
|
||||
setInfillPatchmatchDownscaleSize,
|
||||
setInfillMosaicTileWidth,
|
||||
setInfillMosaicTileHeight,
|
||||
setInfillMosaicMinColor,
|
||||
setInfillMosaicMaxColor,
|
||||
setInfillColorValue,
|
||||
} = generationSlice.actions;
|
||||
|
||||
export const { selectOptimalDimension } = generationSlice.selectors;
|
||||
|
||||
@@ -17,7 +17,6 @@ import type {
|
||||
ParameterVAEModel,
|
||||
ParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import type { RgbaColor } from 'react-colorful';
|
||||
|
||||
export interface GenerationState {
|
||||
_version: 2;
|
||||
@@ -40,6 +39,8 @@ export interface GenerationState {
|
||||
shouldFitToWidthHeight: boolean;
|
||||
shouldRandomizeSeed: boolean;
|
||||
steps: ParameterSteps;
|
||||
infillTileSize: number;
|
||||
infillPatchmatchDownscaleSize: number;
|
||||
width: ParameterWidth;
|
||||
model: ParameterModel | null;
|
||||
vae: ParameterVAEModel | null;
|
||||
@@ -50,13 +51,6 @@ export interface GenerationState {
|
||||
shouldUseCpuNoise: boolean;
|
||||
shouldShowAdvancedOptions: boolean;
|
||||
aspectRatio: AspectRatioState;
|
||||
infillTileSize: number;
|
||||
infillPatchmatchDownscaleSize: number;
|
||||
infillMosaicTileWidth: number;
|
||||
infillMosaicTileHeight: number;
|
||||
infillMosaicMinColor: RgbaColor;
|
||||
infillMosaicMaxColor: RgbaColor;
|
||||
infillColorValue: RgbaColor;
|
||||
}
|
||||
|
||||
export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
|
||||
import type { InvokeTabName } from './tabMap';
|
||||
@@ -46,9 +45,6 @@ export const uiSlice = createSlice({
|
||||
builder.addCase(initialImageChanged, (state) => {
|
||||
state.activeTab = 'img2img';
|
||||
});
|
||||
builder.addCase(workflowLoadRequested, (state) => {
|
||||
state.activeTab = 'nodes';
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@ from invokeai.app.invocations.fields import (
|
||||
OutputField,
|
||||
UIComponent,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
WithWorkflow,
|
||||
)
|
||||
@@ -106,7 +105,6 @@ __all__ = [
|
||||
"OutputField",
|
||||
"UIComponent",
|
||||
"UIType",
|
||||
"WithBoard",
|
||||
"WithMetadata",
|
||||
"WithWorkflow",
|
||||
# invokeai.app.invocations.latent
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.0.4"
|
||||
__version__ = "4.0.2"
|
||||
|
||||
@@ -5,8 +5,9 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.lora import LoRALayer, LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.lora.lora_layer import LoRALayer
|
||||
from invokeai.backend.lora.lora_model import LoRAModelRaw
|
||||
from invokeai.backend.lora.lora_model_patcher import LoraModelPatcher
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -44,7 +45,7 @@ def test_apply_lora(device):
|
||||
orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
|
||||
expected_patched_linear_weight = orig_linear_weight + (lora_dim * lora_weight)
|
||||
|
||||
with ModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
|
||||
with LoraModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
|
||||
# After patching, all LoRA layer weights should have been moved back to the cpu.
|
||||
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
|
||||
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
|
||||
@@ -86,7 +87,7 @@ def test_apply_lora_change_device():
|
||||
|
||||
orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
|
||||
|
||||
with ModelPatcher.apply_lora(model, [(lora, 0.5)], prefix=""):
|
||||
with LoraModelPatcher.apply_lora(model, [(lora, 0.5)], prefix=""):
|
||||
# After patching, all LoRA layer weights should have been moved back to the cpu.
|
||||
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
|
||||
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
|
||||
|
||||