Compare commits
21 Commits
ryan/peft-
...
v4.0.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e330966020 | ||
|
|
b783679b9f | ||
|
|
d32e557e50 | ||
|
|
90686c7f9c | ||
|
|
4571986c63 | ||
|
|
fec989f015 | ||
|
|
b5c048d8bf | ||
|
|
577469be55 | ||
|
|
812f10730f | ||
|
|
3006285d13 | ||
|
|
5d4a571778 | ||
|
|
90bdd74f30 | ||
|
|
d6ccd5bc81 | ||
|
|
f0b1bb0327 | ||
|
|
b061db414f | ||
|
|
e55ab5b3a1 | ||
|
|
adb7966bb3 | ||
|
|
3c195d74a5 | ||
|
|
32a6b758cd | ||
|
|
3659219f46 | ||
|
|
d284e0567a |
@@ -12,7 +12,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
from invokeai.backend.util.logging import logging
|
||||
from invokeai.version import __version__
|
||||
@@ -100,7 +100,7 @@ async def get_app_deps() -> AppDependencyVersions:
|
||||
|
||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||
async def get_config() -> AppConfig:
|
||||
infill_methods = ["tile", "lama", "cv2"]
|
||||
infill_methods = ["tile", "lama", "cv2", "color"] # TODO: add mosaic back
|
||||
if PatchMatch.patchmatch_available():
|
||||
infill_methods.append("patchmatch")
|
||||
|
||||
|
||||
@@ -9,9 +9,8 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.peft.peft_model import PeftModel
|
||||
from invokeai.backend.peft.peft_model_patcher import PeftModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
@@ -62,12 +61,15 @@ class CompelInvocation(BaseInvocation):
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[PeftModel, float]]:
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, PeftModel)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
|
||||
|
||||
@@ -78,7 +80,7 @@ class CompelInvocation(BaseInvocation):
|
||||
),
|
||||
text_encoder_info as text_encoder,
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
PeftModelPatcher.apply_peft_model_to_text_encoder(text_encoder, _lora_loader(), "text_encoder"),
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
|
||||
):
|
||||
@@ -159,13 +161,16 @@ class SDXLPromptInvocationBase:
|
||||
c_pooled = None
|
||||
return c, c_pooled, None
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[PeftModel, float]]:
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
lora_model = lora_info.model
|
||||
assert isinstance(lora_model, PeftModel)
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
yield (lora_model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
|
||||
|
||||
@@ -176,7 +181,7 @@ class SDXLPromptInvocationBase:
|
||||
),
|
||||
text_encoder_info as text_encoder,
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
PeftModelPatcher.apply_peft_model_to_text_encoder(text_encoder, _lora_loader(), lora_prefix),
|
||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
|
||||
):
|
||||
@@ -254,15 +259,15 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(
|
||||
context, self.clip, self.prompt, False, "text_encoder", zero_on_empty=True
|
||||
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
|
||||
)
|
||||
if self.style.strip() == "":
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(
|
||||
context, self.clip2, self.prompt, True, "text_encoder_2", zero_on_empty=True
|
||||
context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
|
||||
)
|
||||
else:
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(
|
||||
context, self.clip2, self.style, True, "text_encoder_2", zero_on_empty=True
|
||||
context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
|
||||
)
|
||||
|
||||
original_size = (self.original_height, self.original_width)
|
||||
|
||||
@@ -1,154 +1,91 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
from abc import abstractmethod
|
||||
from typing import Literal, get_args
|
||||
|
||||
import math
|
||||
from typing import Literal, Optional, get_args
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.fields import ColorField, ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.infill_methods.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.infill_methods.lama import LaMA
|
||||
from invokeai.backend.image_util.infill_methods.mosaic import infill_mosaic
|
||||
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch, infill_patchmatch
|
||||
from invokeai.backend.image_util.infill_methods.tile import infill_tile
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, WithBoard, WithMetadata
|
||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
def infill_methods() -> list[str]:
|
||||
methods = ["tile", "solid", "lama", "cv2"]
|
||||
|
||||
def get_infill_methods():
|
||||
methods = Literal["tile", "color", "lama", "cv2"] # TODO: add mosaic back
|
||||
if PatchMatch.patchmatch_available():
|
||||
methods.insert(0, "patchmatch")
|
||||
methods = Literal["patchmatch", "tile", "color", "lama", "cv2"] # TODO: add mosaic back
|
||||
return methods
|
||||
|
||||
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
INFILL_METHODS = get_infill_methods()
|
||||
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
|
||||
|
||||
def infill_lama(im: Image.Image) -> Image.Image:
|
||||
lama = LaMA()
|
||||
return lama(im)
|
||||
class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Base class for invocations that preprocess images for Infilling"""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
@abstractmethod
|
||||
def infill(self, image: Image.Image) -> Image.Image:
|
||||
"""Infill the image with the specified method"""
|
||||
pass
|
||||
|
||||
# Skip patchmatch if patchmatch isn't available
|
||||
if not PatchMatch.patchmatch_available():
|
||||
return im
|
||||
def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
|
||||
"""Process the image to have an alpha channel before being infilled"""
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
has_alpha = True if image.mode == "RGBA" else False
|
||||
return image, has_alpha
|
||||
|
||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||
im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
|
||||
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
||||
return im_patched
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Retrieve and process image to be infilled
|
||||
input_image, has_alpha = self.load_image(context)
|
||||
|
||||
# If the input image has no alpha channel, return it
|
||||
if has_alpha is False:
|
||||
return ImageOutput.build(context.images.get_dto(self.image.image_name))
|
||||
|
||||
def infill_cv2(im: Image.Image) -> Image.Image:
|
||||
return cv2_inpaint(im)
|
||||
# Perform Infill action
|
||||
infilled_image = self.infill(input_image)
|
||||
|
||||
# Create ImageDTO for Infilled Image
|
||||
infilled_image_dto = context.images.save(image=infilled_image)
|
||||
|
||||
def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
|
||||
nrows, _m = divmod(_nrows, height)
|
||||
ncols, _n = divmod(_ncols, width)
|
||||
if _m != 0 or _n != 0:
|
||||
return None
|
||||
|
||||
return np.lib.stride_tricks.as_strided(
|
||||
np.ravel(image),
|
||||
shape=(nrows, ncols, height, width, depth),
|
||||
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||
writeable=False,
|
||||
)
|
||||
|
||||
|
||||
def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size_tuple = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = get_tile_images(a, *tile_size_tuple).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:, :, :, :, 3]
|
||||
|
||||
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||
tmask_shape = tiles_mask.shape
|
||||
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||
tiles_mask = tiles_mask > 0
|
||||
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
|
||||
|
||||
# Get RGB tiles in single array and filter by the mask
|
||||
tshape = tiles.shape
|
||||
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
|
||||
filtered_tiles = tiles_all[tiles_mask]
|
||||
|
||||
if len(filtered_tiles) == 0:
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum() # noqa: E712
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
tiles_all = tiles_all.swapaxes(1, 2)
|
||||
st = tiles_all.reshape(
|
||||
(
|
||||
math.prod(tiles_all.shape[0:2]),
|
||||
math.prod(tiles_all.shape[2:4]),
|
||||
tiles_all.shape[4],
|
||||
)
|
||||
)
|
||||
si = Image.fromarray(st, mode="RGBA")
|
||||
|
||||
return si
|
||||
# Return Infilled Image
|
||||
return ImageOutput.build(infilled_image_dto)
|
||||
|
||||
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class InfillColorInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
color: ColorField = InputField(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
return infilled
|
||||
|
||||
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.3")
|
||||
class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class InfillTileInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = InputField(
|
||||
default=0,
|
||||
@@ -157,92 +94,74 @@ class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
def infill(self, image: Image.Image):
|
||||
output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
|
||||
return output.infilled
|
||||
|
||||
|
||||
@invocation(
|
||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2"
|
||||
)
|
||||
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name).convert("RGBA")
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
infill_image = image.copy()
|
||||
width = int(image.width / self.downscale)
|
||||
height = int(image.height / self.downscale)
|
||||
infill_image = infill_image.resize(
|
||||
|
||||
infilled = image.resize(
|
||||
(width, height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
if PatchMatch.patchmatch_available():
|
||||
infilled = infill_patchmatch(infill_image)
|
||||
else:
|
||||
raise ValueError("PatchMatch is not available on this system")
|
||||
|
||||
infilled = infill_patchmatch(image)
|
||||
infilled = infilled.resize(
|
||||
(image.width, image.height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
infilled.paste(image, (0, 0), mask=image.split()[-1])
|
||||
# image.paste(infilled, (0, 0), mask=image.split()[-1])
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
return infilled
|
||||
|
||||
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class LaMaInfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Downloads the LaMa model if it doesn't already exist
|
||||
download_with_progress_bar(
|
||||
name="LaMa Inpainting Model",
|
||||
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
dest_path=context.config.get().models_path / "core/misc/lama/lama.pt",
|
||||
)
|
||||
|
||||
infilled = infill_lama(image.copy())
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
def infill(self, image: Image.Image):
|
||||
lama = LaMA()
|
||||
return lama(image)
|
||||
|
||||
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class CV2InfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
return cv2_inpaint(image)
|
||||
|
||||
|
||||
# @invocation(
|
||||
# "infill_mosaic", title="Mosaic Infill", tags=["image", "inpaint", "outpaint"], category="inpaint", version="1.0.0"
|
||||
# )
|
||||
class MosaicInfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image with a mosaic pattern drawing colors from the rest of the image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
tile_width: int = InputField(default=64, description="Width of the tile")
|
||||
tile_height: int = InputField(default=64, description="Height of the tile")
|
||||
min_color: ColorField = InputField(
|
||||
default=ColorField(r=0, g=0, b=0, a=255),
|
||||
description="The min threshold for color",
|
||||
)
|
||||
max_color: ColorField = InputField(
|
||||
default=ColorField(r=255, g=255, b=255, a=255),
|
||||
description="The max threshold for color",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
infilled = infill_cv2(image.copy())
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
def infill(self, image: Image.Image):
|
||||
return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())
|
||||
|
||||
@@ -65,7 +65,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
ui_order=-1,
|
||||
ui_type=UIType.IPAdapterModel,
|
||||
)
|
||||
clip_vision_model: Literal["auto", "ViT-H", "ViT-G"] = InputField(
|
||||
clip_vision_model: Literal["ViT-H", "ViT-G"] = InputField(
|
||||
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
|
||||
default="auto",
|
||||
ui_order=2,
|
||||
@@ -96,14 +96,9 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
|
||||
|
||||
if self.clip_vision_model == "auto":
|
||||
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"You need to set the appropriate CLIP Vision model for checkpoint IP Adapter models."
|
||||
)
|
||||
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
else:
|
||||
image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
|
||||
|
||||
|
||||
@@ -48,10 +48,9 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.peft.peft_model import PeftModel
|
||||
from invokeai.backend.peft.peft_model_patcher import PeftModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
@@ -715,12 +714,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, unet_config.base)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[PeftModel, float]]:
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, PeftModel)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
@@ -730,7 +730,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
|
||||
unet_info as unet,
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
PeftModelPatcher.apply_peft_model_to_unet(unet, _lora_loader()),
|
||||
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
@@ -1254,7 +1254,7 @@ class IdealSizeInvocation(BaseInvocation):
|
||||
return tuple((x - x % multiple_of) for x in args)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
|
||||
unet_config = context.models.get_config(**self.unet.unet.model_dump())
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
aspect = self.width / self.height
|
||||
dimension: float = 512
|
||||
if unet_config.base == BaseModelType.StableDiffusion2:
|
||||
|
||||
@@ -5,8 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
@@ -6,8 +6,7 @@ from typing import Optional, Type
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import (
|
||||
LoadedModel,
|
||||
ModelLoaderRegistry,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Initialization file for model manager service."""
|
||||
|
||||
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
|
||||
from .model_manager_default import ModelManagerService, ModelManagerServiceBase
|
||||
@@ -8,6 +8,7 @@ from .model_manager_default import ModelManagerService, ModelManagerServiceBase
|
||||
__all__ = [
|
||||
"ModelManagerServiceBase",
|
||||
"ModelManagerService",
|
||||
"AnyModel",
|
||||
"AnyModelConfig",
|
||||
"BaseModelType",
|
||||
"ModelType",
|
||||
|
||||
@@ -80,6 +80,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=app_config.ram,
|
||||
max_vram_cache_size=app_config.vram,
|
||||
lazy_offloading=app_config.lazy_offload,
|
||||
logger=logger,
|
||||
execution_device=execution_device,
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Initialization file for invokeai.backend.image_util methods.
|
||||
"""
|
||||
|
||||
from .patchmatch import PatchMatch # noqa: F401
|
||||
from .infill_methods.patchmatch import PatchMatch # noqa: F401
|
||||
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
|
||||
from .seamless import configure_model_padding # noqa: F401
|
||||
from .util import InitImageResizer, make_grid # noqa: F401
|
||||
|
||||
@@ -7,6 +7,7 @@ from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
|
||||
@@ -30,6 +31,14 @@ class LaMA:
|
||||
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
||||
device = choose_torch_device()
|
||||
model_location = get_config().models_path / "core/misc/lama/lama.pt"
|
||||
|
||||
if not model_location.exists():
|
||||
download_with_progress_bar(
|
||||
name="LaMa Inpainting Model",
|
||||
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
dest_path=model_location,
|
||||
)
|
||||
|
||||
model = load_jit_model(model_location, device)
|
||||
|
||||
image = np.asarray(input_image.convert("RGB"))
|
||||
60
invokeai/backend/image_util/infill_methods/mosaic.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def infill_mosaic(
|
||||
image: Image.Image,
|
||||
tile_shape: Tuple[int, int] = (64, 64),
|
||||
min_color: Tuple[int, int, int, int] = (0, 0, 0, 0),
|
||||
max_color: Tuple[int, int, int, int] = (255, 255, 255, 0),
|
||||
) -> Image.Image:
|
||||
"""
|
||||
image:PIL - A PIL Image
|
||||
tile_shape: Tuple[int,int] - Tile width & Tile Height
|
||||
min_color: Tuple[int,int,int] - RGB values for the lowest color to clip to (0-255)
|
||||
max_color: Tuple[int,int,int] - RGB values for the highest color to clip to (0-255)
|
||||
"""
|
||||
|
||||
np_image = np.array(image) # Convert image to np array
|
||||
alpha = np_image[:, :, 3] # Get the mask from the alpha channel of the image
|
||||
non_transparent_pixels = np_image[alpha != 0, :3] # List of non-transparent pixels
|
||||
|
||||
# Create color tiles to paste in the empty areas of the image
|
||||
tile_width, tile_height = tile_shape
|
||||
|
||||
# Clip the range of colors in the image to a particular spectrum only
|
||||
r_min, g_min, b_min, _ = min_color
|
||||
r_max, g_max, b_max, _ = max_color
|
||||
non_transparent_pixels[:, 0] = np.clip(non_transparent_pixels[:, 0], r_min, r_max)
|
||||
non_transparent_pixels[:, 1] = np.clip(non_transparent_pixels[:, 1], g_min, g_max)
|
||||
non_transparent_pixels[:, 2] = np.clip(non_transparent_pixels[:, 2], b_min, b_max)
|
||||
|
||||
tiles = []
|
||||
for _ in range(256):
|
||||
color = non_transparent_pixels[np.random.randint(len(non_transparent_pixels))]
|
||||
tile = np.zeros((tile_height, tile_width, 3), dtype=np.uint8)
|
||||
tile[:, :] = color
|
||||
tiles.append(tile)
|
||||
|
||||
# Fill the transparent area with tiles
|
||||
filled_image = np.zeros((image.height, image.width, 3), dtype=np.uint8)
|
||||
|
||||
for x in range(image.width):
|
||||
for y in range(image.height):
|
||||
tile = tiles[np.random.randint(len(tiles))]
|
||||
try:
|
||||
filled_image[
|
||||
y - (y % tile_height) : y - (y % tile_height) + tile_height,
|
||||
x - (x % tile_width) : x - (x % tile_width) + tile_width,
|
||||
] = tile
|
||||
except ValueError:
|
||||
# Need to handle edge cases - literally
|
||||
pass
|
||||
|
||||
filled_image = Image.fromarray(filled_image) # Convert the filled tiles image to PIL
|
||||
image = Image.composite(
|
||||
image, filled_image, image.split()[-1]
|
||||
) # Composite the original image on top of the filled tiles
|
||||
return image
|
||||
67
invokeai/backend/image_util/infill_methods/patchmatch.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
This module defines a singleton object, "patchmatch" that
|
||||
wraps the actual patchmatch object. It respects the global
|
||||
"try_patchmatch" attribute, so that patchmatch loading can
|
||||
be suppressed or deferred
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
|
||||
class PatchMatch:
|
||||
"""
|
||||
Thin class wrapper around the patchmatch function.
|
||||
"""
|
||||
|
||||
patch_match = None
|
||||
tried_load: bool = False
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def _load_patch_match(cls):
|
||||
if cls.tried_load:
|
||||
return
|
||||
if get_config().patchmatch:
|
||||
from patchmatch import patch_match as pm
|
||||
|
||||
if pm.patchmatch_available:
|
||||
logger.info("Patchmatch initialized")
|
||||
cls.patch_match = pm
|
||||
else:
|
||||
logger.info("Patchmatch not loaded (nonfatal)")
|
||||
else:
|
||||
logger.info("Patchmatch loading disabled")
|
||||
cls.tried_load = True
|
||||
|
||||
@classmethod
|
||||
def patchmatch_available(cls) -> bool:
|
||||
cls._load_patch_match()
|
||||
if not cls.patch_match:
|
||||
return False
|
||||
return cls.patch_match.patchmatch_available
|
||||
|
||||
@classmethod
|
||||
def inpaint(cls, image: Image.Image) -> Image.Image:
|
||||
if cls.patch_match is None or not cls.patchmatch_available():
|
||||
return image
|
||||
|
||||
np_image = np.array(image)
|
||||
mask = 255 - np_image[:, :, 3]
|
||||
infilled = cls.patch_match.inpaint(np_image[:, :, :3], mask, patch_size=3)
|
||||
return Image.fromarray(infilled, mode="RGB")
|
||||
|
||||
|
||||
def infill_patchmatch(image: Image.Image) -> Image.Image:
|
||||
IS_PATCHMATCH_AVAILABLE = PatchMatch.patchmatch_available()
|
||||
|
||||
if not IS_PATCHMATCH_AVAILABLE:
|
||||
logger.warning("PatchMatch is not available on this system")
|
||||
return image
|
||||
|
||||
return PatchMatch.inpaint(image)
|
||||
|
After Width: | Height: | Size: 45 KiB |
|
After Width: | Height: | Size: 2.2 KiB |
|
After Width: | Height: | Size: 36 KiB |
|
After Width: | Height: | Size: 33 KiB |
|
After Width: | Height: | Size: 21 KiB |
|
After Width: | Height: | Size: 39 KiB |
|
After Width: | Height: | Size: 42 KiB |
|
After Width: | Height: | Size: 48 KiB |
|
After Width: | Height: | Size: 49 KiB |
|
After Width: | Height: | Size: 60 KiB |
95
invokeai/backend/image_util/infill_methods/tile.ipynb
Normal file
@@ -0,0 +1,95 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\"\"\"Smoke test for the tile infill\"\"\"\n",
|
||||
"\n",
|
||||
"from pathlib import Path\n",
|
||||
"from typing import Optional\n",
|
||||
"from PIL import Image\n",
|
||||
"from invokeai.backend.image_util.infill_methods.tile import infill_tile\n",
|
||||
"\n",
|
||||
"images: list[tuple[str, Image.Image]] = []\n",
|
||||
"\n",
|
||||
"for i in sorted(Path(\"./test_images/\").glob(\"*.webp\")):\n",
|
||||
" images.append((i.name, Image.open(i)))\n",
|
||||
" images.append((i.name, Image.open(i).transpose(Image.FLIP_LEFT_RIGHT)))\n",
|
||||
" images.append((i.name, Image.open(i).transpose(Image.FLIP_TOP_BOTTOM)))\n",
|
||||
" images.append((i.name, Image.open(i).resize((512, 512))))\n",
|
||||
" images.append((i.name, Image.open(i).resize((1234, 461))))\n",
|
||||
"\n",
|
||||
"outputs: list[tuple[str, Image.Image, Image.Image, Optional[Image.Image]]] = []\n",
|
||||
"\n",
|
||||
"for name, image in images:\n",
|
||||
" try:\n",
|
||||
" output = infill_tile(image, seed=0, tile_size=32)\n",
|
||||
" outputs.append((name, image, output.infilled, output.tile_image))\n",
|
||||
" except ValueError as e:\n",
|
||||
" print(f\"Skipping image {name}: {e}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Display the images in jupyter notebook\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from PIL import ImageOps\n",
|
||||
"\n",
|
||||
"fig, axes = plt.subplots(len(outputs), 3, figsize=(10, 3 * len(outputs)))\n",
|
||||
"plt.subplots_adjust(hspace=0)\n",
|
||||
"\n",
|
||||
"for i, (name, original, infilled, tile_image) in enumerate(outputs):\n",
|
||||
" # Add a border to each image, helps to see the edges\n",
|
||||
" size = original.size\n",
|
||||
" original = ImageOps.expand(original, border=5, fill=\"red\")\n",
|
||||
" filled = ImageOps.expand(infilled, border=5, fill=\"red\")\n",
|
||||
" if tile_image:\n",
|
||||
" tile_image = ImageOps.expand(tile_image, border=5, fill=\"red\")\n",
|
||||
"\n",
|
||||
" axes[i, 0].imshow(original)\n",
|
||||
" axes[i, 0].axis(\"off\")\n",
|
||||
" axes[i, 0].set_title(f\"Original ({name} - {size})\")\n",
|
||||
"\n",
|
||||
" if tile_image:\n",
|
||||
" axes[i, 1].imshow(tile_image)\n",
|
||||
" axes[i, 1].axis(\"off\")\n",
|
||||
" axes[i, 1].set_title(\"Tile Image\")\n",
|
||||
" else:\n",
|
||||
" axes[i, 1].axis(\"off\")\n",
|
||||
" axes[i, 1].set_title(\"NO TILES GENERATED (NO TRANSPARENCY)\")\n",
|
||||
"\n",
|
||||
" axes[i, 2].imshow(filled)\n",
|
||||
" axes[i, 2].axis(\"off\")\n",
|
||||
" axes[i, 2].set_title(\"Filled\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".invokeai",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
122
invokeai/backend/image_util/infill_methods/tile.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def create_tile_pool(img_array: np.ndarray, tile_size: tuple[int, int]) -> list[np.ndarray]:
|
||||
"""
|
||||
Create a pool of tiles from non-transparent areas of the image by systematically walking through the image.
|
||||
|
||||
Args:
|
||||
img_array: numpy array of the image.
|
||||
tile_size: tuple (tile_width, tile_height) specifying the size of each tile.
|
||||
|
||||
Returns:
|
||||
A list of numpy arrays, each representing a tile.
|
||||
"""
|
||||
tiles: list[np.ndarray] = []
|
||||
rows, cols = img_array.shape[:2]
|
||||
tile_width, tile_height = tile_size
|
||||
|
||||
for y in range(0, rows - tile_height + 1, tile_height):
|
||||
for x in range(0, cols - tile_width + 1, tile_width):
|
||||
tile = img_array[y : y + tile_height, x : x + tile_width]
|
||||
# Check if the image has an alpha channel and the tile is completely opaque
|
||||
if img_array.shape[2] == 4 and np.all(tile[:, :, 3] == 255):
|
||||
tiles.append(tile)
|
||||
elif img_array.shape[2] == 3: # If no alpha channel, append the tile
|
||||
tiles.append(tile)
|
||||
|
||||
if not tiles:
|
||||
raise ValueError(
|
||||
"Not enough opaque pixels to generate any tiles. Use a smaller tile size or a different image."
|
||||
)
|
||||
|
||||
return tiles
|
||||
|
||||
|
||||
def create_filled_image(
|
||||
img_array: np.ndarray, tile_pool: list[np.ndarray], tile_size: tuple[int, int], seed: int
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Create an image of the same dimensions as the original, filled entirely with tiles from the pool.
|
||||
|
||||
Args:
|
||||
img_array: numpy array of the original image.
|
||||
tile_pool: A list of numpy arrays, each representing a tile.
|
||||
tile_size: tuple (tile_width, tile_height) specifying the size of each tile.
|
||||
|
||||
Returns:
|
||||
A numpy array representing the filled image.
|
||||
"""
|
||||
|
||||
rows, cols, _ = img_array.shape
|
||||
tile_width, tile_height = tile_size
|
||||
|
||||
# Prep an empty RGB image
|
||||
filled_img_array = np.zeros((rows, cols, 3), dtype=img_array.dtype)
|
||||
|
||||
# Make the random tile selection reproducible
|
||||
rng = np.random.default_rng(seed)
|
||||
|
||||
for y in range(0, rows, tile_height):
|
||||
for x in range(0, cols, tile_width):
|
||||
# Pick a random tile from the pool
|
||||
tile = tile_pool[rng.integers(len(tile_pool))]
|
||||
|
||||
# Calculate the space available (may be less than tile size near the edges)
|
||||
space_y = min(tile_height, rows - y)
|
||||
space_x = min(tile_width, cols - x)
|
||||
|
||||
# Crop the tile if necessary to fit into the available space
|
||||
cropped_tile = tile[:space_y, :space_x, :3]
|
||||
|
||||
# Fill the available space with the (possibly cropped) tile
|
||||
filled_img_array[y : y + space_y, x : x + space_x, :3] = cropped_tile
|
||||
|
||||
return filled_img_array
|
||||
|
||||
|
||||
@dataclass
|
||||
class InfillTileOutput:
|
||||
infilled: Image.Image
|
||||
tile_image: Optional[Image.Image] = None
|
||||
|
||||
|
||||
def infill_tile(image_to_infill: Image.Image, seed: int, tile_size: int) -> InfillTileOutput:
|
||||
"""Infills an image with random tiles from the image itself.
|
||||
|
||||
If the image is not an RGBA image, it is returned untouched.
|
||||
|
||||
Args:
|
||||
image: The image to infill.
|
||||
tile_size: The size of the tiles to use for infilling.
|
||||
|
||||
Raises:
|
||||
ValueError: If there are not enough opaque pixels to generate any tiles.
|
||||
"""
|
||||
|
||||
if image_to_infill.mode != "RGBA":
|
||||
return InfillTileOutput(infilled=image_to_infill)
|
||||
|
||||
# Internally, we want a tuple of (tile_width, tile_height). In the future, the tile size can be any rectangle.
|
||||
_tile_size = (tile_size, tile_size)
|
||||
np_image = np.array(image_to_infill, dtype=np.uint8)
|
||||
|
||||
# Create the pool of tiles that we will use to infill
|
||||
tile_pool = create_tile_pool(np_image, _tile_size)
|
||||
|
||||
# Create an image from the tiles, same size as the original
|
||||
tile_np_image = create_filled_image(np_image, tile_pool, _tile_size, seed)
|
||||
|
||||
# Paste the OG image over the tile image, effectively infilling the area
|
||||
tile_image = Image.fromarray(tile_np_image, "RGB")
|
||||
infilled = tile_image.copy()
|
||||
infilled.paste(image_to_infill, (0, 0), image_to_infill.split()[-1])
|
||||
|
||||
# I think we want this to be "RGBA"?
|
||||
infilled.convert("RGBA")
|
||||
|
||||
return InfillTileOutput(infilled=infilled, tile_image=tile_image)
|
||||
@@ -1,49 +0,0 @@
|
||||
"""
|
||||
This module defines a singleton object, "patchmatch" that
|
||||
wraps the actual patchmatch object. It respects the global
|
||||
"try_patchmatch" attribute, so that patchmatch loading can
|
||||
be suppressed or deferred
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
|
||||
class PatchMatch:
|
||||
"""
|
||||
Thin class wrapper around the patchmatch function.
|
||||
"""
|
||||
|
||||
patch_match = None
|
||||
tried_load: bool = False
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def _load_patch_match(self):
|
||||
if self.tried_load:
|
||||
return
|
||||
if get_config().patchmatch:
|
||||
from patchmatch import patch_match as pm
|
||||
|
||||
if pm.patchmatch_available:
|
||||
logger.info("Patchmatch initialized")
|
||||
else:
|
||||
logger.info("Patchmatch not loaded (nonfatal)")
|
||||
self.patch_match = pm
|
||||
else:
|
||||
logger.info("Patchmatch loading disabled")
|
||||
self.tried_load = True
|
||||
|
||||
@classmethod
|
||||
def patchmatch_available(self) -> bool:
|
||||
self._load_patch_match()
|
||||
return self.patch_match and self.patch_match.patchmatch_available
|
||||
|
||||
@classmethod
|
||||
def inpaint(self, *args, **kwargs) -> np.ndarray:
|
||||
if self.patchmatch_available():
|
||||
return self.patch_match.inpaint(*args, **kwargs)
|
||||
@@ -12,6 +12,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
|
||||
|
||||
from ..raw_model import RawModel
|
||||
from .resampler import Resampler
|
||||
|
||||
|
||||
@@ -101,7 +102,7 @@ class MLPProjModel(torch.nn.Module):
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
class IPAdapter(torch.nn.Module):
|
||||
class IPAdapter(RawModel):
|
||||
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
|
||||
|
||||
def __init__(
|
||||
@@ -111,7 +112,6 @@ class IPAdapter(torch.nn.Module):
|
||||
dtype: torch.dtype = torch.float16,
|
||||
num_tokens: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ from typing_extensions import Self
|
||||
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
|
||||
from .raw_model import RawModel
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
@@ -366,13 +368,15 @@ class IA3Layer(LoRALayerBase):
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||
|
||||
|
||||
class LoRAModelRaw(torch.nn.Module):
|
||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
_name: str
|
||||
layers: Dict[str, AnyLoRALayer]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
layers: Dict[str, AnyLoRALayer],
|
||||
):
|
||||
super().__init__()
|
||||
self._name = name
|
||||
self.layers = layers
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Re-export frequently-used symbols from the Model Manager backend."""
|
||||
|
||||
from .config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
@@ -17,6 +18,7 @@ from .probe import ModelProbe
|
||||
from .search import ModelSearch
|
||||
|
||||
__all__ = [
|
||||
"AnyModel",
|
||||
"AnyModelConfig",
|
||||
"BaseModelType",
|
||||
"ModelRepoVariant",
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.peft.peft_model import PeftModel
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
AnyModel = Union[ModelMixin, torch.nn.Module, IPAdapter, PeftModel, TextualInversionModelRaw, IAIOnnxRuntimeModel]
|
||||
@@ -24,12 +24,20 @@ import time
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Type, TypeAlias, Union
|
||||
|
||||
import torch
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
from ..raw_model import RawModel
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||
|
||||
@@ -15,7 +15,7 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
)
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from . import AnyModel
|
||||
|
||||
|
||||
def convert_ldm_vae_to_diffusers(
|
||||
|
||||
@@ -10,8 +10,8 @@ from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
@@ -7,11 +7,11 @@ from typing import Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
InvalidModelConfigException,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
|
||||
@@ -14,8 +14,7 @@ from typing import Dict, Generic, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
from invokeai.backend.model_manager.config import AnyModel, SubModelType
|
||||
|
||||
|
||||
class ModelLockerBase(ABC):
|
||||
|
||||
@@ -28,8 +28,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@@ -422,13 +421,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
|
||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||
|
||||
def _free_vram(self, device: torch.device) -> int:
|
||||
vram_device = ( # mem_get_info() needs an indexed device
|
||||
device if device.index is not None else torch.device(str(device), index=0)
|
||||
)
|
||||
free_mem, _ = torch.cuda.mem_get_info(vram_device)
|
||||
for _, cache_entry in self._cached_models.items():
|
||||
if cache_entry.loaded and not cache_entry.locked:
|
||||
free_mem += cache_entry.size
|
||||
return free_mem
|
||||
|
||||
def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
|
||||
if target_device.type != "cuda":
|
||||
return
|
||||
vram_device = ( # mem_get_info() needs an indexed device
|
||||
target_device if target_device.index is not None else torch.device(str(target_device), index=0)
|
||||
)
|
||||
free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
|
||||
free_mem = self._free_vram(target_device)
|
||||
if needed_size > free_mem:
|
||||
needed_gb = round(needed_size / GIG, 2)
|
||||
free_gb = round(free_mem / GIG, 2)
|
||||
|
||||
@@ -4,7 +4,7 @@ Base class and implementation of a class that moves models in and out of VRAM.
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
|
||||
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
|
||||
|
||||
@@ -34,7 +34,6 @@ class ModelLocker(ModelLockerBase):
|
||||
|
||||
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
|
||||
self._cache_entry.lock()
|
||||
|
||||
try:
|
||||
if self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||
@@ -51,6 +50,7 @@ class ModelLocker(ModelLockerBase):
|
||||
except Exception:
|
||||
self._cache_entry.unlock()
|
||||
raise
|
||||
|
||||
return self.model
|
||||
|
||||
def unlock(self) -> None:
|
||||
|
||||
@@ -5,12 +5,12 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
@@ -16,7 +17,6 @@ from invokeai.backend.model_manager import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
@@ -7,9 +7,9 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
|
||||
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
|
||||
@@ -25,7 +25,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in an IP-Adapter model.")
|
||||
model_path = Path(config.path)
|
||||
model = build_ip_adapter(
|
||||
model: RawModel = build_ip_adapter(
|
||||
ip_adapter_ckpt_path=model_path,
|
||||
device=torch.device("cpu"),
|
||||
dtype=self._torch_dtype,
|
||||
|
||||
@@ -6,17 +6,17 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.peft.peft_model import PeftModel
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
@@ -47,7 +47,7 @@ class LoRALoader(ModelLoader):
|
||||
raise ValueError("There are no submodels in a LoRA model.")
|
||||
model_path = Path(config.path)
|
||||
assert self._model_base is not None
|
||||
model = PeftModel.from_checkpoint(
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
dtype=self._torch_dtype,
|
||||
base_model=self._model_base,
|
||||
|
||||
@@ -6,13 +6,13 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
@@ -12,7 +13,6 @@ from invokeai.backend.model_manager import (
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
DiffusersConfigBase,
|
||||
|
||||
@@ -5,13 +5,13 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
@@ -14,8 +14,7 @@ from invokeai.backend.model_manager import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.config import AnyModel, CheckpointConfigBase
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Optional
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ def skip_torch_weight_init() -> Generator[None, None, None]:
|
||||
completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
|
||||
monkey-patches common torch layers to skip the weight initialization step.
|
||||
"""
|
||||
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding, torch.nn.LayerNorm]
|
||||
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
|
||||
saved_functions = [hasattr(m, "reset_parameters") and m.reset_parameters for m in torch_modules]
|
||||
|
||||
try:
|
||||
|
||||
@@ -13,7 +13,7 @@ from diffusers import OnnxRuntimeModel, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.any_model_type import AnyModel
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
|
||||
|
||||
@@ -6,16 +6,17 @@ from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
from onnx import numpy_helper
|
||||
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
||||
|
||||
from ..raw_model import RawModel
|
||||
|
||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
|
||||
|
||||
# NOTE FROM LS: This was copied from Stalker's original implementation.
|
||||
# I have not yet gone through and fixed all the type hints
|
||||
class IAIOnnxRuntimeModel(torch.nn.Module):
|
||||
class IAIOnnxRuntimeModel(RawModel):
|
||||
class _tensor_access:
|
||||
def __init__(self, model): # type: ignore
|
||||
self.model = model
|
||||
@@ -102,7 +103,7 @@ class IAIOnnxRuntimeModel(torch.nn.Module):
|
||||
|
||||
self.proto = onnx.load(model_path, load_external_data=False)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.proto = onnx.load(model_path, load_external_data=True)
|
||||
# self.data = dict()
|
||||
# for tensor in self.proto.graph.initializer:
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
import torch
|
||||
from diffusers.utils.state_dict_utils import convert_state_dict
|
||||
|
||||
KOHYA_SS_TO_PEFT = {
|
||||
"lora_down": "lora_A",
|
||||
"lora_up": "lora_B",
|
||||
# This is not a comprehensive dict. See `convert_state_dict_to_peft(...)` for more info on the conversion.
|
||||
}
|
||||
|
||||
|
||||
def convert_state_dict_kohya_to_peft(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
# TODO(ryand): Check that state_dict is in Kohya format.
|
||||
|
||||
peft_partial_state_dict = convert_state_dict(state_dict, KOHYA_SS_TO_PEFT)
|
||||
|
||||
peft_state_dict: dict[str, torch.Tensor] = {}
|
||||
for key, weight in peft_partial_state_dict.items():
|
||||
|
||||
|
||||
for kohya_key, weight in kohya_ss_partial_state_dict.items():
|
||||
if "text_encoder_2." in kohya_key:
|
||||
kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
|
||||
elif "text_encoder." in kohya_key:
|
||||
kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
|
||||
elif "unet" in kohya_key:
|
||||
kohya_key = kohya_key.replace("unet", "lora_unet")
|
||||
kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
|
||||
kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names
|
||||
kohya_ss_state_dict[kohya_key] = weight
|
||||
if "lora_down" in kohya_key:
|
||||
alpha_key = f'{kohya_key.split(".")[0]}.alpha'
|
||||
kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
|
||||
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
|
||||
r"""
|
||||
Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc.
|
||||
The method only supports the conversion from PEFT to Kohya for now.
|
||||
|
||||
Args:
|
||||
state_dict (`dict[str, torch.Tensor]`):
|
||||
The state dict to convert.
|
||||
original_type (`StateDictType`, *optional*):
|
||||
The original type of the state dict, if not provided, the method will try to infer it automatically.
|
||||
kwargs (`dict`, *args*):
|
||||
Additional arguments to pass to the method.
|
||||
|
||||
- **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
|
||||
with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
|
||||
`get_peft_model_state_dict` method:
|
||||
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
|
||||
but we add it here in case we don't want to rely on that method.
|
||||
"""
|
||||
|
||||
peft_adapter_name = kwargs.pop("adapter_name", None)
|
||||
if peft_adapter_name is not None:
|
||||
peft_adapter_name = "." + peft_adapter_name
|
||||
else:
|
||||
peft_adapter_name = ""
|
||||
|
||||
if original_type is None:
|
||||
if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.PEFT
|
||||
|
||||
if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys():
|
||||
raise ValueError(f"Original type {original_type} is not supported")
|
||||
|
||||
# Use the convert_state_dict function with the appropriate mapping
|
||||
kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT])
|
||||
kohya_ss_state_dict = {}
|
||||
|
||||
# Additional logic for replacing header, alpha parameters `.` with `_` in all keys
|
||||
for kohya_key, weight in kohya_ss_partial_state_dict.items():
|
||||
if "text_encoder_2." in kohya_key:
|
||||
kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
|
||||
elif "text_encoder." in kohya_key:
|
||||
kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
|
||||
elif "unet" in kohya_key:
|
||||
kohya_key = kohya_key.replace("unet", "lora_unet")
|
||||
kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
|
||||
kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names
|
||||
kohya_ss_state_dict[kohya_key] = weight
|
||||
if "lora_down" in kohya_key:
|
||||
alpha_key = f'{kohya_key.split(".")[0]}.alpha'
|
||||
kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
|
||||
|
||||
return kohya_ss_state_dict
|
||||
@@ -1,52 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from diffusers.loaders.lora_conversion_utils import _convert_kohya_lora_to_diffusers
|
||||
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.backend.peft.sdxl_format_utils import convert_sdxl_keys_to_diffusers_format
|
||||
from invokeai.backend.util.serialization import load_state_dict
|
||||
|
||||
|
||||
class PeftModel:
|
||||
"""A class for loading and managing parameter-efficient fine-tuning models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
state_dict: dict[str, torch.Tensor],
|
||||
network_alphas: dict[str, torch.Tensor],
|
||||
):
|
||||
self.name = name
|
||||
self.state_dict = state_dict
|
||||
self.network_alphas = network_alphas
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for tensor in self.state_dict.values():
|
||||
model_size += tensor.nelement() * tensor.element_size()
|
||||
return model_size
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
file_path: Union[str, Path],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
):
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
file_path = Path(file_path)
|
||||
|
||||
state_dict = load_state_dict(file_path, device=str(device))
|
||||
# lora_unet_up_blocks_1_attentions_2_transformer_blocks_1_ff_net_2.lora_down.weight
|
||||
if base_model == BaseModelType.StableDiffusionXL:
|
||||
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
|
||||
# TODO(ryand): We shouldn't be using an unexported function from diffusers here. Consider opening an upstream PR
|
||||
# to move this function to state_dict_utils.py.
|
||||
state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
|
||||
return cls(name=file_path.stem, state_dict=state_dict, network_alphas=network_alphas)
|
||||
@@ -1,227 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.utils.peft_utils import get_peft_kwargs, scale_lora_layers
|
||||
from diffusers.utils.state_dict_utils import convert_state_dict_to_peft, convert_unet_state_dict_to_peft
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
from invokeai.backend.peft.peft_model import PeftModel
|
||||
|
||||
UNET_NAME = "unet"
|
||||
|
||||
|
||||
class PeftModelPatcher:
|
||||
@classmethod
|
||||
@contextmanager
|
||||
@torch.no_grad()
|
||||
def apply_peft_model_to_text_encoder(
|
||||
cls,
|
||||
text_encoder: torch.nn.Module,
|
||||
peft_models: Iterator[Tuple[PeftModel, float]],
|
||||
prefix: str,
|
||||
):
|
||||
original_weights = {}
|
||||
|
||||
try:
|
||||
for peft_model, peft_model_weight in peft_models:
|
||||
keys = list(peft_model.state_dict.keys())
|
||||
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
text_encoder_lora_state_dict = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in peft_model.state_dict.items() if k in text_encoder_keys
|
||||
}
|
||||
|
||||
if len(text_encoder_lora_state_dict) == 0:
|
||||
continue
|
||||
|
||||
if peft_model.name in getattr(text_encoder, "peft_config", {}):
|
||||
raise ValueError(f"Adapter name {peft_model.name} already in use in the text encoder ({prefix}).")
|
||||
|
||||
rank = {}
|
||||
# TODO(ryand): Is this necessary?
|
||||
# text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_B.weight"
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
|
||||
|
||||
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
|
||||
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
|
||||
|
||||
network_alphas = peft_model.network_alphas
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
|
||||
]
|
||||
network_alphas = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
lora_config_kwargs["inference_mode"] = True
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
new_text_encoder = inject_adapter_in_model(lora_config, text_encoder, peft_model.name)
|
||||
incompatible_keys = set_peft_model_state_dict(
|
||||
new_text_encoder, text_encoder_lora_state_dict, peft_model.name
|
||||
)
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
raise ValueError(f"Failed to inject unexpected PEFT keys: {unexpected_keys}")
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
# text_encoder.load_adapter(
|
||||
# adapter_name=adapter_name,
|
||||
# adapter_state_dict=text_encoder_lora_state_dict,
|
||||
# peft_config=lora_config,
|
||||
# )
|
||||
|
||||
scale_lora_layers(text_encoder, weight=peft_model_weight)
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
yield
|
||||
finally:
|
||||
# TODO
|
||||
pass
|
||||
# for module_key, weight in original_weights.items():
|
||||
# model.get_submodule(module_key).weight.copy_(weight)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
@torch.no_grad()
|
||||
def apply_peft_model_to_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
peft_models: Iterator[Tuple[PeftModel, float]],
|
||||
):
|
||||
try:
|
||||
for peft_model, peft_model_weight in peft_models:
|
||||
keys = list(peft_model.state_dict.keys())
|
||||
|
||||
unet_keys = [k for k in keys if k.startswith(UNET_NAME)]
|
||||
state_dict = {
|
||||
k.replace(f"{UNET_NAME}.", ""): v for k, v in peft_model.state_dict.items() if k in unet_keys
|
||||
}
|
||||
|
||||
network_alphas = peft_model.network_alphas
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(UNET_NAME)]
|
||||
network_alphas = {
|
||||
k.replace(f"{UNET_NAME}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
if len(state_dict) == 0:
|
||||
continue
|
||||
|
||||
if peft_model.name in getattr(unet, "peft_config", {}):
|
||||
raise ValueError(f"Adapter name {peft_model.name} already in use in the Unet.")
|
||||
|
||||
state_dict = convert_unet_state_dict_to_peft(state_dict)
|
||||
|
||||
if network_alphas is not None:
|
||||
# The alphas state dict have the same structure as Unet, thus we convert it to peft format using
|
||||
# `convert_unet_state_dict_to_peft` method.
|
||||
network_alphas = convert_unet_state_dict_to_peft(network_alphas)
|
||||
|
||||
rank = {}
|
||||
for key, val in state_dict.items():
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
|
||||
lora_config_kwargs["inference_mode"] = True
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
inject_adapter_in_model(lora_config, unet, adapter_name=peft_model.name)
|
||||
incompatible_keys = set_peft_model_state_dict(unet, state_dict, peft_model.name)
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
raise ValueError(f"Failed to inject unexpected PEFT keys: {unexpected_keys}")
|
||||
|
||||
# TODO(ryand): What does this do?
|
||||
unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=True)
|
||||
|
||||
# TODO(ryand): Apply the lora weight. Where does diffusers do this? They don't seem to do it when they
|
||||
# patch the UNet.
|
||||
yield
|
||||
finally:
|
||||
# TODO
|
||||
pass
|
||||
# for module_key, weight in original_weights.items():
|
||||
# model.get_submodule(module_key).weight.copy_(weight)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
@torch.no_grad()
|
||||
def apply_peft_patch(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
peft_models: Iterator[Tuple[PeftModel, float]],
|
||||
prefix: str,
|
||||
):
|
||||
original_weights = {}
|
||||
|
||||
model_state_dict = model.state_dict()
|
||||
try:
|
||||
for peft_model, peft_model_weight in peft_models:
|
||||
for layer_key, layer in peft_model.state_dict.items():
|
||||
if not layer_key.startswith(prefix):
|
||||
continue
|
||||
|
||||
module_key = layer_key.replace(prefix + ".", "")
|
||||
# TODO(ryand): Make this work.
|
||||
|
||||
module = model_state_dict[module_key]
|
||||
|
||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||
# (Performance will be best if this is a CUDA device.)
|
||||
device = module.weight.device
|
||||
dtype = module.weight.dtype
|
||||
|
||||
if module_key not in original_weights:
|
||||
# TODO(ryand): Set non_blocking = True?
|
||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||
|
||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device)
|
||||
layer.to(dtype=torch.float32)
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
||||
layer.to(device=torch.device("cpu"))
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
if module.weight.shape != layer_weight.shape:
|
||||
# TODO: debug on lycoris
|
||||
assert hasattr(layer_weight, "reshape")
|
||||
layer_weight = layer_weight.reshape(module.weight.shape)
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
module.weight += layer_weight.to(dtype=dtype)
|
||||
yield
|
||||
finally:
|
||||
for module_key, weight in original_weights.items():
|
||||
model.get_submodule(module_key).weight.copy_(weight)
|
||||
@@ -1,154 +0,0 @@
|
||||
import bisect
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
diffusers format, then this function will have no effect.
|
||||
|
||||
This function is adapted from:
|
||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
|
||||
|
||||
Args:
|
||||
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
|
||||
|
||||
Returns:
|
||||
Dict[str, Tensor]: The diffusers-format state_dict.
|
||||
"""
|
||||
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
|
||||
not_converted_count = 0 # The number of keys that were not converted.
|
||||
|
||||
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
|
||||
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
|
||||
# `input_blocks_4_1_proj_in`.
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
|
||||
position = bisect.bisect_right(stability_unet_keys, search_key)
|
||||
map_key = stability_unet_keys[position - 1]
|
||||
# Now, check if the map_key *actually* matches the search_key.
|
||||
if search_key.startswith(map_key):
|
||||
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
|
||||
new_state_dict[new_key] = value
|
||||
converted_count += 1
|
||||
else:
|
||||
new_state_dict[full_key] = value
|
||||
not_converted_count += 1
|
||||
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
|
||||
new_state_dict[full_key] = value
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
|
||||
|
||||
if converted_count > 0 and not_converted_count > 0:
|
||||
raise ValueError(
|
||||
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
|
||||
f" not_converted={not_converted_count}"
|
||||
)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# Code based on:
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def make_sdxl_unet_conversion_map() -> list[tuple[str, str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
||||
unet_conversion_map_layer: list[tuple[str, str]] = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
# if i > 0: commentout for sdxl
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0.", "norm1."),
|
||||
("in_layers.2.", "conv1."),
|
||||
("out_layers.0.", "norm2."),
|
||||
("out_layers.3.", "conv2."),
|
||||
("emb_layers.1.", "time_emb_proj."),
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map: list[tuple[str, str]] = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||
else:
|
||||
unet_conversion_map.append((sd, hf))
|
||||
|
||||
for j in range(2):
|
||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||
|
||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||
|
||||
return unet_conversion_map
|
||||
|
||||
|
||||
# A mapping of state_dict key prefixes from Stability AI SDXL format to diffusers SDXL format.
|
||||
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
||||
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
|
||||
}
|
||||
15
invokeai/backend/raw_model.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Base class for 'Raw' models.
|
||||
|
||||
The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
|
||||
and is used for type checking of calls to the model patcher. Its main purpose
|
||||
is to avoid a circular import issues when lora.py tries to import BaseModelType
|
||||
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
|
||||
from lora.py.
|
||||
|
||||
The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
|
||||
that adds additional methods and attributes.
|
||||
"""
|
||||
|
||||
|
||||
class RawModel:
|
||||
"""Base class for 'Raw' model wrappers."""
|
||||
@@ -9,8 +9,10 @@ from safetensors.torch import load_file
|
||||
from transformers import CLIPTokenizer
|
||||
from typing_extensions import Self
|
||||
|
||||
from .raw_model import RawModel
|
||||
|
||||
class TextualInversionModelRaw(torch.nn.Module):
|
||||
|
||||
class TextualInversionModelRaw(RawModel):
|
||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
|
||||
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
def state_dict_to(
|
||||
state_dict: dict[str, torch.Tensor], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
new_state_dict: dict[str, torch.Tensor] = {}
|
||||
for k, v in state_dict.items():
|
||||
new_state_dict[k] = v.to(device=device, dtype=dtype, non_blocking=True)
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def load_state_dict(file_path: Union[str, Path], device: str = "cpu") -> Any:
|
||||
"""Load a state_dict from a file that may be in either PyTorch or safetensors format. The file format is inferred
|
||||
from the file extension.
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
state_dict = load_file(
|
||||
file_path,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
# weights_only=True is used to address a security vulnerability that allows arbitrary code execution.
|
||||
# This option was first introduced in https://github.com/pytorch/pytorch/pull/86812.
|
||||
#
|
||||
# mmap=True is used to both reduce memory usage and speed up loading. This setting causes torch.load() to more
|
||||
# closely mirror the behaviour of safetensors.torch.load_file(). This option was first introduced in
|
||||
# https://github.com/pytorch/pytorch/pull/102549. The discussion on that PR provides helpful context.
|
||||
state_dict = torch.load(file_path, map_location=device, weights_only=True, mmap=True)
|
||||
|
||||
return state_dict
|
||||
@@ -75,7 +75,8 @@
|
||||
"copy": "Kopieren",
|
||||
"aboutHeading": "Nutzen Sie Ihre kreative Energie",
|
||||
"toResolve": "Lösen",
|
||||
"add": "Hinzufügen"
|
||||
"add": "Hinzufügen",
|
||||
"loglevel": "Protokoll Stufe"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Bildgröße",
|
||||
@@ -388,7 +389,14 @@
|
||||
"vaePrecision": "VAE-Präzision",
|
||||
"variant": "Variante",
|
||||
"modelDeleteFailed": "Modell konnte nicht gelöscht werden",
|
||||
"noModelSelected": "Kein Modell ausgewählt"
|
||||
"noModelSelected": "Kein Modell ausgewählt",
|
||||
"huggingFace": "HuggingFace",
|
||||
"defaultSettings": "Standardeinstellungen",
|
||||
"edit": "Bearbeiten",
|
||||
"cancel": "Stornieren",
|
||||
"defaultSettingsSaved": "Standardeinstellungen gespeichert",
|
||||
"addModels": "Model hinzufügen",
|
||||
"deleteModelImage": "Lösche Model Bild"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Bilder",
|
||||
@@ -677,7 +685,8 @@
|
||||
"body": "Körper",
|
||||
"hands": "Hände",
|
||||
"dwOpenpose": "DW Openpose",
|
||||
"dwOpenposeDescription": "Posenschätzung mit DW Openpose"
|
||||
"dwOpenposeDescription": "Posenschätzung mit DW Openpose",
|
||||
"selectCLIPVisionModel": "Wähle ein CLIP Vision Model aus"
|
||||
},
|
||||
"queue": {
|
||||
"status": "Status",
|
||||
@@ -765,7 +774,10 @@
|
||||
"recallParameters": "Parameter wiederherstellen",
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
||||
"allPrompts": "Alle Prompts",
|
||||
"imageDimensions": "Bilder Auslösungen"
|
||||
"imageDimensions": "Bilder Auslösungen",
|
||||
"parameterSet": "Parameter {{parameter}} setzen",
|
||||
"recallParameter": "{{label}} Abrufen",
|
||||
"parsingFailed": "Parsing Fehlgeschlagen"
|
||||
},
|
||||
"popovers": {
|
||||
"noiseUseCPU": {
|
||||
@@ -1030,7 +1042,8 @@
|
||||
"title": "Bild"
|
||||
},
|
||||
"advanced": {
|
||||
"title": "Erweitert"
|
||||
"title": "Erweitert",
|
||||
"options": "$t(accordions.advanced.title) Optionen"
|
||||
},
|
||||
"control": {
|
||||
"title": "Kontrolle"
|
||||
|
||||
@@ -684,6 +684,7 @@
|
||||
"noModelsInstalled": "No Models Installed",
|
||||
"noModelsInstalledDesc1": "Install models with the",
|
||||
"noModelSelected": "No Model Selected",
|
||||
"noMatchingModels": "No matching Models",
|
||||
"none": "none",
|
||||
"path": "Path",
|
||||
"pathToConfig": "Path To Config",
|
||||
@@ -887,6 +888,11 @@
|
||||
"imageFit": "Fit Initial Image To Output Size",
|
||||
"images": "Images",
|
||||
"infillMethod": "Infill Method",
|
||||
"infillMosaicTileWidth": "Tile Width",
|
||||
"infillMosaicTileHeight": "Tile Height",
|
||||
"infillMosaicMinColor": "Min Color",
|
||||
"infillMosaicMaxColor": "Max Color",
|
||||
"infillColorValue": "Fill Color",
|
||||
"info": "Info",
|
||||
"invoke": {
|
||||
"addingImagesTo": "Adding images to",
|
||||
@@ -1417,6 +1423,7 @@
|
||||
"eraseBoundingBox": "Erase Bounding Box",
|
||||
"eraser": "Eraser",
|
||||
"fillBoundingBox": "Fill Bounding Box",
|
||||
"initialFitImageSize": "Fit Image Size on Drop",
|
||||
"invertBrushSizeScrollDirection": "Invert Scroll for Brush Size",
|
||||
"layer": "Layer",
|
||||
"limitStrokesToBox": "Limit Strokes to Box",
|
||||
|
||||
@@ -366,7 +366,7 @@
|
||||
"modelConverted": "Modello convertito",
|
||||
"alpha": "Alpha",
|
||||
"convertToDiffusersHelpText1": "Questo modello verrà convertito nel formato 🧨 Diffusori.",
|
||||
"convertToDiffusersHelpText3": "Il file Checkpoint su disco verrà eliminato se si trova nella cartella principale di InvokeAI. Se si trova invece in una posizione personalizzata, NON verrà eliminato.",
|
||||
"convertToDiffusersHelpText3": "Il file del modello su disco verrà eliminato se si trova nella cartella principale di InvokeAI. Se si trova invece in una posizione personalizzata, NON verrà eliminato.",
|
||||
"v2_base": "v2 (512px)",
|
||||
"v2_768": "v2 (768px)",
|
||||
"none": "nessuno",
|
||||
@@ -443,7 +443,8 @@
|
||||
"noModelsInstalled": "Nessun modello installato",
|
||||
"hfTokenInvalidErrorMessage2": "Aggiornalo in ",
|
||||
"main": "Principali",
|
||||
"noModelsInstalledDesc1": "Installa i modelli con"
|
||||
"noModelsInstalledDesc1": "Installa i modelli con",
|
||||
"ipAdapters": "Adattatori IP"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Immagini",
|
||||
@@ -937,7 +938,8 @@
|
||||
"controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
|
||||
"mediapipeFace": "Mediapipe Volto",
|
||||
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
|
||||
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))"
|
||||
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))",
|
||||
"selectCLIPVisionModel": "Seleziona un modello CLIP Vision"
|
||||
},
|
||||
"queue": {
|
||||
"queueFront": "Aggiungi all'inizio della coda",
|
||||
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
setShouldAutoSave,
|
||||
setShouldCropToBoundingBoxOnSave,
|
||||
setShouldDarkenOutsideBoundingBox,
|
||||
setShouldFitImageSize,
|
||||
setShouldInvertBrushSizeScrollDirection,
|
||||
setShouldRestrictStrokesToBox,
|
||||
setShouldShowCanvasDebugInfo,
|
||||
@@ -48,6 +49,7 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
const shouldSnapToGrid = useAppSelector((s) => s.canvas.shouldSnapToGrid);
|
||||
const shouldRestrictStrokesToBox = useAppSelector((s) => s.canvas.shouldRestrictStrokesToBox);
|
||||
const shouldAntialias = useAppSelector((s) => s.canvas.shouldAntialias);
|
||||
const shouldFitImageSize = useAppSelector((s) => s.canvas.shouldFitImageSize);
|
||||
|
||||
useHotkeys(
|
||||
['n'],
|
||||
@@ -102,6 +104,10 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldAntialias(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
const handleChangeShouldFitImageSize = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => dispatch(setShouldFitImageSize(e.target.checked)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
@@ -165,6 +171,10 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
<FormLabel>{t('unifiedCanvas.antialiasing')}</FormLabel>
|
||||
<Checkbox isChecked={shouldAntialias} onChange={handleChangeShouldAntialias} />
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('unifiedCanvas.initialFitImageSize')}</FormLabel>
|
||||
<Checkbox isChecked={shouldFitImageSize} onChange={handleChangeShouldFitImageSize} />
|
||||
</FormControl>
|
||||
</FormControlGroup>
|
||||
<ClearCanvasHistoryButtonModal />
|
||||
</Flex>
|
||||
|
||||
@@ -66,6 +66,7 @@ const initialCanvasState: CanvasState = {
|
||||
shouldAutoSave: false,
|
||||
shouldCropToBoundingBoxOnSave: false,
|
||||
shouldDarkenOutsideBoundingBox: false,
|
||||
shouldFitImageSize: true,
|
||||
shouldInvertBrushSizeScrollDirection: false,
|
||||
shouldLockBoundingBox: false,
|
||||
shouldPreserveMaskedArea: false,
|
||||
@@ -144,12 +145,20 @@ export const canvasSlice = createSlice({
|
||||
reducer: (state, action: PayloadActionWithOptimalDimension<ImageDTO>) => {
|
||||
const { width, height, image_name } = action.payload;
|
||||
const { optimalDimension } = action.meta;
|
||||
const { stageDimensions } = state;
|
||||
const { stageDimensions, shouldFitImageSize } = state;
|
||||
|
||||
const newBoundingBoxDimensions = {
|
||||
width: roundDownToMultiple(clamp(width, CANVAS_GRID_SIZE_FINE, optimalDimension), CANVAS_GRID_SIZE_FINE),
|
||||
height: roundDownToMultiple(clamp(height, CANVAS_GRID_SIZE_FINE, optimalDimension), CANVAS_GRID_SIZE_FINE),
|
||||
};
|
||||
const newBoundingBoxDimensions = shouldFitImageSize
|
||||
? {
|
||||
width: roundDownToMultiple(width, CANVAS_GRID_SIZE_FINE),
|
||||
height: roundDownToMultiple(height, CANVAS_GRID_SIZE_FINE),
|
||||
}
|
||||
: {
|
||||
width: roundDownToMultiple(clamp(width, CANVAS_GRID_SIZE_FINE, optimalDimension), CANVAS_GRID_SIZE_FINE),
|
||||
height: roundDownToMultiple(
|
||||
clamp(height, CANVAS_GRID_SIZE_FINE, optimalDimension),
|
||||
CANVAS_GRID_SIZE_FINE
|
||||
),
|
||||
};
|
||||
|
||||
const newBoundingBoxCoordinates = {
|
||||
x: roundToMultiple(width / 2 - newBoundingBoxDimensions.width / 2, CANVAS_GRID_SIZE_FINE),
|
||||
@@ -582,6 +591,9 @@ export const canvasSlice = createSlice({
|
||||
setShouldAntialias: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldAntialias = action.payload;
|
||||
},
|
||||
setShouldFitImageSize: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldFitImageSize = action.payload;
|
||||
},
|
||||
setShouldCropToBoundingBoxOnSave: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldCropToBoundingBoxOnSave = action.payload;
|
||||
},
|
||||
@@ -692,6 +704,7 @@ export const {
|
||||
setShouldRestrictStrokesToBox,
|
||||
stagingAreaInitialized,
|
||||
setShouldAntialias,
|
||||
setShouldFitImageSize,
|
||||
canvasResized,
|
||||
canvasBatchIdAdded,
|
||||
canvasBatchIdsReset,
|
||||
|
||||
@@ -120,6 +120,7 @@ export interface CanvasState {
|
||||
shouldAutoSave: boolean;
|
||||
shouldCropToBoundingBoxOnSave: boolean;
|
||||
shouldDarkenOutsideBoundingBox: boolean;
|
||||
shouldFitImageSize: boolean;
|
||||
shouldInvertBrushSizeScrollDirection: boolean;
|
||||
shouldLockBoundingBox: boolean;
|
||||
shouldPreserveMaskedArea: boolean;
|
||||
|
||||
@@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig } from 'app/store/store';
|
||||
import type { ModelType } from 'services/api/types';
|
||||
|
||||
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>;
|
||||
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
|
||||
|
||||
type ModelManagerState = {
|
||||
_version: 1;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
@@ -9,10 +10,11 @@ import {
|
||||
useIPAdapterModels,
|
||||
useLoRAModels,
|
||||
useMainModels,
|
||||
useRefinerModels,
|
||||
useT2IAdapterModels,
|
||||
useVAEModels,
|
||||
} from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig, ModelType } from 'services/api/types';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import { FetchingModelsLoader } from './FetchingModelsLoader';
|
||||
import { ModelListWrapper } from './ModelListWrapper';
|
||||
@@ -27,6 +29,12 @@ const ModelList = () => {
|
||||
[mainModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [refinerModels, { isLoading: isLoadingRefinerModels }] = useRefinerModels();
|
||||
const filteredRefinerModels = useMemo(
|
||||
() => modelsFilter(refinerModels, searchTerm, filteredModelType),
|
||||
[refinerModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
|
||||
const filteredLoRAModels = useMemo(
|
||||
() => modelsFilter(loraModels, searchTerm, filteredModelType),
|
||||
@@ -63,6 +71,28 @@ const ModelList = () => {
|
||||
[vaeModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const totalFilteredModels = useMemo(() => {
|
||||
return (
|
||||
filteredMainModels.length +
|
||||
filteredRefinerModels.length +
|
||||
filteredLoRAModels.length +
|
||||
filteredEmbeddingModels.length +
|
||||
filteredControlNetModels.length +
|
||||
filteredT2IAdapterModels.length +
|
||||
filteredIPAdapterModels.length +
|
||||
filteredVAEModels.length
|
||||
);
|
||||
}, [
|
||||
filteredControlNetModels.length,
|
||||
filteredEmbeddingModels.length,
|
||||
filteredIPAdapterModels.length,
|
||||
filteredLoRAModels.length,
|
||||
filteredMainModels.length,
|
||||
filteredRefinerModels.length,
|
||||
filteredT2IAdapterModels.length,
|
||||
filteredVAEModels.length,
|
||||
]);
|
||||
|
||||
return (
|
||||
<ScrollableContent>
|
||||
<Flex flexDirection="column" w="full" h="full" gap={4}>
|
||||
@@ -71,6 +101,11 @@ const ModelList = () => {
|
||||
{!isLoadingMainModels && filteredMainModels.length > 0 && (
|
||||
<ModelListWrapper title={t('modelManager.main')} modelList={filteredMainModels} key="main" />
|
||||
)}
|
||||
{/* Refiner Model List */}
|
||||
{isLoadingRefinerModels && <FetchingModelsLoader loadingMessage="Loading Refiner Models..." />}
|
||||
{!isLoadingRefinerModels && filteredRefinerModels.length > 0 && (
|
||||
<ModelListWrapper title={t('sdxl.refiner')} modelList={filteredRefinerModels} key="refiner" />
|
||||
)}
|
||||
{/* LoRAs List */}
|
||||
{isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
||||
{!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
|
||||
@@ -108,6 +143,11 @@ const ModelList = () => {
|
||||
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
|
||||
<ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
|
||||
)}
|
||||
{totalFilteredModels === 0 && (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<Text>{t('modelManager.noMatchingModels')}</Text>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
);
|
||||
@@ -118,12 +158,24 @@ export default memo(ModelList);
|
||||
const modelsFilter = <T extends AnyModelConfig>(
|
||||
data: T[],
|
||||
nameFilter: string,
|
||||
filteredModelType: ModelType | null
|
||||
filteredModelType: FilterableModelType | null
|
||||
): T[] => {
|
||||
return data.filter((model) => {
|
||||
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
||||
const matchesType = filteredModelType ? model.type === filteredModelType : true;
|
||||
const matchesType = getMatchesType(model, filteredModelType);
|
||||
|
||||
return matchesFilter && matchesType;
|
||||
});
|
||||
};
|
||||
|
||||
const getMatchesType = (modelConfig: AnyModelConfig, filteredModelType: FilterableModelType | null): boolean => {
|
||||
if (filteredModelType === 'refiner') {
|
||||
return modelConfig.base === 'sdxl-refiner';
|
||||
}
|
||||
|
||||
if (filteredModelType === 'main' && modelConfig.base === 'sdxl-refiner') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return filteredModelType ? modelConfig.type === filteredModelType : true;
|
||||
};
|
||||
|
||||
@@ -13,6 +13,7 @@ export const ModelTypeFilter = () => {
|
||||
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
|
||||
() => ({
|
||||
main: t('modelManager.main'),
|
||||
refiner: t('sdxl.refiner'),
|
||||
lora: 'LoRA',
|
||||
embedding: t('modelManager.textualInversions'),
|
||||
controlnet: 'ControlNet',
|
||||
|
||||
@@ -65,6 +65,11 @@ export const buildCanvasOutpaintGraph = async (
|
||||
infillTileSize,
|
||||
infillPatchmatchDownscaleSize,
|
||||
infillMethod,
|
||||
// infillMosaicTileWidth,
|
||||
// infillMosaicTileHeight,
|
||||
// infillMosaicMinColor,
|
||||
// infillMosaicMaxColor,
|
||||
infillColorValue,
|
||||
clipSkip,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
@@ -356,6 +361,28 @@ export const buildCanvasOutpaintGraph = async (
|
||||
};
|
||||
}
|
||||
|
||||
// TODO: add mosaic back
|
||||
// if (infillMethod === 'mosaic') {
|
||||
// graph.nodes[INPAINT_INFILL] = {
|
||||
// type: 'infill_mosaic',
|
||||
// id: INPAINT_INFILL,
|
||||
// is_intermediate,
|
||||
// tile_width: infillMosaicTileWidth,
|
||||
// tile_height: infillMosaicTileHeight,
|
||||
// min_color: infillMosaicMinColor,
|
||||
// max_color: infillMosaicMaxColor,
|
||||
// };
|
||||
// }
|
||||
|
||||
if (infillMethod === 'color') {
|
||||
graph.nodes[INPAINT_INFILL] = {
|
||||
type: 'infill_rgba',
|
||||
id: INPAINT_INFILL,
|
||||
color: infillColorValue,
|
||||
is_intermediate,
|
||||
};
|
||||
}
|
||||
|
||||
// Handle Scale Before Processing
|
||||
if (isUsingScaledDimensions) {
|
||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||
|
||||
@@ -66,6 +66,11 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
||||
infillTileSize,
|
||||
infillPatchmatchDownscaleSize,
|
||||
infillMethod,
|
||||
// infillMosaicTileWidth,
|
||||
// infillMosaicTileHeight,
|
||||
// infillMosaicMinColor,
|
||||
// infillMosaicMaxColor,
|
||||
infillColorValue,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
canvasCoherenceMode,
|
||||
@@ -365,6 +370,28 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
||||
};
|
||||
}
|
||||
|
||||
// TODO: add mosaic back
|
||||
// if (infillMethod === 'mosaic') {
|
||||
// graph.nodes[INPAINT_INFILL] = {
|
||||
// type: 'infill_mosaic',
|
||||
// id: INPAINT_INFILL,
|
||||
// is_intermediate,
|
||||
// tile_width: infillMosaicTileWidth,
|
||||
// tile_height: infillMosaicTileHeight,
|
||||
// min_color: infillMosaicMinColor,
|
||||
// max_color: infillMosaicMaxColor,
|
||||
// };
|
||||
// }
|
||||
|
||||
if (infillMethod === 'color') {
|
||||
graph.nodes[INPAINT_INFILL] = {
|
||||
type: 'infill_rgba',
|
||||
id: INPAINT_INFILL,
|
||||
is_intermediate,
|
||||
color: infillColorValue,
|
||||
};
|
||||
}
|
||||
|
||||
// Handle Scale Before Processing
|
||||
if (isUsingScaledDimensions) {
|
||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIColorPicker from 'common/components/IAIColorPicker';
|
||||
import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import type { RgbaColor } from 'react-colorful';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const ParamInfillColorOptions = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectGenerationSlice, (generation) => ({
|
||||
infillColor: generation.infillColorValue,
|
||||
})),
|
||||
[]
|
||||
);
|
||||
|
||||
const { infillColor } = useAppSelector(selector);
|
||||
|
||||
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleInfillColor = useCallback(
|
||||
(v: RgbaColor) => {
|
||||
dispatch(setInfillColorValue(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<FormControl isDisabled={infillMethod !== 'color'}>
|
||||
<FormLabel>{t('parameters.infillColorValue')}</FormLabel>
|
||||
<Box w="full" pt={2} pb={2}>
|
||||
<IAIColorPicker color={infillColor} onChange={handleInfillColor} />
|
||||
</Box>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamInfillColorOptions);
|
||||
@@ -0,0 +1,127 @@
|
||||
import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIColorPicker from 'common/components/IAIColorPicker';
|
||||
import {
|
||||
selectGenerationSlice,
|
||||
setInfillMosaicMaxColor,
|
||||
setInfillMosaicMinColor,
|
||||
setInfillMosaicTileHeight,
|
||||
setInfillMosaicTileWidth,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import type { RgbaColor } from 'react-colorful';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const ParamInfillMosaicTileSize = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectGenerationSlice, (generation) => ({
|
||||
infillMosaicTileWidth: generation.infillMosaicTileWidth,
|
||||
infillMosaicTileHeight: generation.infillMosaicTileHeight,
|
||||
infillMosaicMinColor: generation.infillMosaicMinColor,
|
||||
infillMosaicMaxColor: generation.infillMosaicMaxColor,
|
||||
})),
|
||||
[]
|
||||
);
|
||||
|
||||
const { infillMosaicTileWidth, infillMosaicTileHeight, infillMosaicMinColor, infillMosaicMaxColor } =
|
||||
useAppSelector(selector);
|
||||
|
||||
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleInfillMosaicTileWidthChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(setInfillMosaicTileWidth(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleInfillMosaicTileHeightChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(setInfillMosaicTileHeight(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleInfillMosaicMinColor = useCallback(
|
||||
(v: RgbaColor) => {
|
||||
dispatch(setInfillMosaicMinColor(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleInfillMosaicMaxColor = useCallback(
|
||||
(v: RgbaColor) => {
|
||||
dispatch(setInfillMosaicMaxColor(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<FormControl isDisabled={infillMethod !== 'mosaic'}>
|
||||
<FormLabel>{t('parameters.infillMosaicTileWidth')}</FormLabel>
|
||||
<CompositeSlider
|
||||
min={8}
|
||||
max={256}
|
||||
value={infillMosaicTileWidth}
|
||||
defaultValue={64}
|
||||
onChange={handleInfillMosaicTileWidthChange}
|
||||
step={8}
|
||||
fineStep={8}
|
||||
marks
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
min={8}
|
||||
max={1024}
|
||||
value={infillMosaicTileWidth}
|
||||
defaultValue={64}
|
||||
onChange={handleInfillMosaicTileWidthChange}
|
||||
step={8}
|
||||
fineStep={8}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl isDisabled={infillMethod !== 'mosaic'}>
|
||||
<FormLabel>{t('parameters.infillMosaicTileHeight')}</FormLabel>
|
||||
<CompositeSlider
|
||||
min={8}
|
||||
max={256}
|
||||
value={infillMosaicTileHeight}
|
||||
defaultValue={64}
|
||||
onChange={handleInfillMosaicTileHeightChange}
|
||||
step={8}
|
||||
fineStep={8}
|
||||
marks
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
min={8}
|
||||
max={1024}
|
||||
value={infillMosaicTileHeight}
|
||||
defaultValue={64}
|
||||
onChange={handleInfillMosaicTileHeightChange}
|
||||
step={8}
|
||||
fineStep={8}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl isDisabled={infillMethod !== 'mosaic'}>
|
||||
<FormLabel>{t('parameters.infillMosaicMinColor')}</FormLabel>
|
||||
<Box w="full" pt={2} pb={2}>
|
||||
<IAIColorPicker color={infillMosaicMinColor} onChange={handleInfillMosaicMinColor} />
|
||||
</Box>
|
||||
</FormControl>
|
||||
<FormControl isDisabled={infillMethod !== 'mosaic'}>
|
||||
<FormLabel>{t('parameters.infillMosaicMaxColor')}</FormLabel>
|
||||
<Box w="full" pt={2} pb={2}>
|
||||
<IAIColorPicker color={infillMosaicMaxColor} onChange={handleInfillMosaicMaxColor} />
|
||||
</Box>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamInfillMosaicTileSize);
|
||||
@@ -1,6 +1,8 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo } from 'react';
|
||||
|
||||
import ParamInfillColorOptions from './ParamInfillColorOptions';
|
||||
import ParamInfillMosaicOptions from './ParamInfillMosaicOptions';
|
||||
import ParamInfillPatchmatchDownscaleSize from './ParamInfillPatchmatchDownscaleSize';
|
||||
import ParamInfillTilesize from './ParamInfillTilesize';
|
||||
|
||||
@@ -14,6 +16,14 @@ const ParamInfillOptions = () => {
|
||||
return <ParamInfillPatchmatchDownscaleSize />;
|
||||
}
|
||||
|
||||
if (infillMethod === 'mosaic') {
|
||||
return <ParamInfillMosaicOptions />;
|
||||
}
|
||||
|
||||
if (infillMethod === 'color') {
|
||||
return <ParamInfillColorOptions />;
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import type {
|
||||
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { clamp } from 'lodash-es';
|
||||
import type { RgbaColor } from 'react-colorful';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import type { GenerationState } from './types';
|
||||
@@ -43,8 +44,6 @@ const initialGenerationState: GenerationState = {
|
||||
shouldFitToWidthHeight: true,
|
||||
shouldRandomizeSeed: true,
|
||||
steps: 50,
|
||||
infillTileSize: 32,
|
||||
infillPatchmatchDownscaleSize: 1,
|
||||
width: 512,
|
||||
model: null,
|
||||
vae: null,
|
||||
@@ -55,6 +54,13 @@ const initialGenerationState: GenerationState = {
|
||||
shouldUseCpuNoise: true,
|
||||
shouldShowAdvancedOptions: false,
|
||||
aspectRatio: { ...initialAspectRatioState },
|
||||
infillTileSize: 32,
|
||||
infillPatchmatchDownscaleSize: 1,
|
||||
infillMosaicTileWidth: 64,
|
||||
infillMosaicTileHeight: 64,
|
||||
infillMosaicMinColor: { r: 0, g: 0, b: 0, a: 1 },
|
||||
infillMosaicMaxColor: { r: 255, g: 255, b: 255, a: 1 },
|
||||
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
|
||||
};
|
||||
|
||||
export const generationSlice = createSlice({
|
||||
@@ -116,15 +122,6 @@ export const generationSlice = createSlice({
|
||||
setCanvasCoherenceMinDenoise: (state, action: PayloadAction<number>) => {
|
||||
state.canvasCoherenceMinDenoise = action.payload;
|
||||
},
|
||||
setInfillMethod: (state, action: PayloadAction<string>) => {
|
||||
state.infillMethod = action.payload;
|
||||
},
|
||||
setInfillTileSize: (state, action: PayloadAction<number>) => {
|
||||
state.infillTileSize = action.payload;
|
||||
},
|
||||
setInfillPatchmatchDownscaleSize: (state, action: PayloadAction<number>) => {
|
||||
state.infillPatchmatchDownscaleSize = action.payload;
|
||||
},
|
||||
initialImageChanged: (state, action: PayloadAction<ImageDTO>) => {
|
||||
const { image_name, width, height } = action.payload;
|
||||
state.initialImage = { imageName: image_name, width, height };
|
||||
@@ -206,6 +203,30 @@ export const generationSlice = createSlice({
|
||||
aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
|
||||
state.aspectRatio = action.payload;
|
||||
},
|
||||
setInfillMethod: (state, action: PayloadAction<string>) => {
|
||||
state.infillMethod = action.payload;
|
||||
},
|
||||
setInfillTileSize: (state, action: PayloadAction<number>) => {
|
||||
state.infillTileSize = action.payload;
|
||||
},
|
||||
setInfillPatchmatchDownscaleSize: (state, action: PayloadAction<number>) => {
|
||||
state.infillPatchmatchDownscaleSize = action.payload;
|
||||
},
|
||||
setInfillMosaicTileWidth: (state, action: PayloadAction<number>) => {
|
||||
state.infillMosaicTileWidth = action.payload;
|
||||
},
|
||||
setInfillMosaicTileHeight: (state, action: PayloadAction<number>) => {
|
||||
state.infillMosaicTileHeight = action.payload;
|
||||
},
|
||||
setInfillMosaicMinColor: (state, action: PayloadAction<RgbaColor>) => {
|
||||
state.infillMosaicMinColor = action.payload;
|
||||
},
|
||||
setInfillMosaicMaxColor: (state, action: PayloadAction<RgbaColor>) => {
|
||||
state.infillMosaicMaxColor = action.payload;
|
||||
},
|
||||
setInfillColorValue: (state, action: PayloadAction<RgbaColor>) => {
|
||||
state.infillColorValue = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(configChanged, (state, action) => {
|
||||
@@ -249,8 +270,6 @@ export const {
|
||||
setShouldFitToWidthHeight,
|
||||
setShouldRandomizeSeed,
|
||||
setSteps,
|
||||
setInfillTileSize,
|
||||
setInfillPatchmatchDownscaleSize,
|
||||
initialImageChanged,
|
||||
modelChanged,
|
||||
vaeSelected,
|
||||
@@ -264,6 +283,13 @@ export const {
|
||||
heightChanged,
|
||||
widthRecalled,
|
||||
heightRecalled,
|
||||
setInfillTileSize,
|
||||
setInfillPatchmatchDownscaleSize,
|
||||
setInfillMosaicTileWidth,
|
||||
setInfillMosaicTileHeight,
|
||||
setInfillMosaicMinColor,
|
||||
setInfillMosaicMaxColor,
|
||||
setInfillColorValue,
|
||||
} = generationSlice.actions;
|
||||
|
||||
export const { selectOptimalDimension } = generationSlice.selectors;
|
||||
|
||||
@@ -17,6 +17,7 @@ import type {
|
||||
ParameterVAEModel,
|
||||
ParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import type { RgbaColor } from 'react-colorful';
|
||||
|
||||
export interface GenerationState {
|
||||
_version: 2;
|
||||
@@ -39,8 +40,6 @@ export interface GenerationState {
|
||||
shouldFitToWidthHeight: boolean;
|
||||
shouldRandomizeSeed: boolean;
|
||||
steps: ParameterSteps;
|
||||
infillTileSize: number;
|
||||
infillPatchmatchDownscaleSize: number;
|
||||
width: ParameterWidth;
|
||||
model: ParameterModel | null;
|
||||
vae: ParameterVAEModel | null;
|
||||
@@ -51,6 +50,13 @@ export interface GenerationState {
|
||||
shouldUseCpuNoise: boolean;
|
||||
shouldShowAdvancedOptions: boolean;
|
||||
aspectRatio: AspectRatioState;
|
||||
infillTileSize: number;
|
||||
infillPatchmatchDownscaleSize: number;
|
||||
infillMosaicTileWidth: number;
|
||||
infillMosaicTileHeight: number;
|
||||
infillMosaicMinColor: RgbaColor;
|
||||
infillMosaicMaxColor: RgbaColor;
|
||||
infillColorValue: RgbaColor;
|
||||
}
|
||||
|
||||
export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
|
||||
import type { InvokeTabName } from './tabMap';
|
||||
@@ -45,6 +46,9 @@ export const uiSlice = createSlice({
|
||||
builder.addCase(initialImageChanged, (state) => {
|
||||
state.activeTab = 'img2img';
|
||||
});
|
||||
builder.addCase(workflowLoadRequested, (state) => {
|
||||
state.activeTab = 'nodes';
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.0.2"
|
||||
__version__ = "4.0.3"
|
||||
|
||||