Compare commits

..

2 Commits

Author           SHA1         Message                                                                      Date
Ryan Dick        9b763b9e4c   Fix issue with seamless context managers when seamless is not configured.   2024-01-05 10:31:58 -05:00
Sergey Borisov   7f3be627c2   Add more seamless configuration options.                                    2024-01-05 09:57:28 -05:00
18 changed files with 114 additions and 101 deletions

View File

@@ -1,5 +1,6 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
+import contextlib
 from contextlib import ExitStack
 from functools import singledispatchmethod
 from typing import List, Literal, Optional, Union
@@ -716,10 +717,23 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 **self.unet.unet.model_dump(),
                 context=context,
             )
+
+            # Prepare seamless context, if configured.
+            seamless_context = contextlib.nullcontext()
+            seamless_config = self.unet.seamless
+            if seamless_config is not None:
+                seamless_context = set_seamless(
+                    model=unet_info.context.model,
+                    axes=seamless_config.axes,
+                    skipped_layers=seamless_config.skipped_layers,
+                    skip_second_resnet=seamless_config.skip_second_resnet,
+                    skip_conv2=seamless_config.skip_conv2,
+                )
+
             with (
                 ExitStack() as exit_stack,
                 ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
-                set_seamless(unet_info.context.model, self.unet.seamless_axes),
+                seamless_context,
                 unet_info as unet,
                 # Apply the LoRA after unet has been moved to its target device for faster patching.
                 ModelPatcher.apply_lora_unet(unet, _lora_loader()),
@@ -826,7 +840,19 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
             context=context,
         )
-        with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
+        # Prepare seamless context, if configured.
+        seamless_context = contextlib.nullcontext()
+        seamless_config = self.vae.seamless
+        if seamless_config is not None:
+            seamless_context = set_seamless(
+                model=vae_info.context.model,
+                axes=seamless_config.axes,
+                skipped_layers=seamless_config.skipped_layers,
+                skip_second_resnet=seamless_config.skip_second_resnet,
+                skip_conv2=seamless_config.skip_conv2,
+            )
+
+        with seamless_context, vae_info as vae:
             latents = latents.to(vae.device)
             if self.fp32:
                 vae.to(dtype=torch.float32)
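
Note on the two hunks above: the unconditional set_seamless(...) entry in the with-statement is replaced by a context object that is only a real patcher when seamless settings are present, and a no-op otherwise. A minimal, self-contained sketch of that pattern (hypothetical names, not the invocation code itself):

import contextlib
from typing import List, Optional


@contextlib.contextmanager
def fake_seamless(axes: List[str]):
    # Stand-in for set_seamless(): patch something on enter, restore on exit.
    print(f"seamless enabled for axes={axes}")
    try:
        yield
    finally:
        print("seamless restored")


def make_seamless_context(axes: Optional[List[str]]):
    if axes is None:
        # No seamless config: a do-nothing context keeps the `with` block unchanged.
        return contextlib.nullcontext()
    return fake_seamless(axes)


for cfg in (None, ["x", "y"]):
    with make_seamless_context(cfg):
        print("denoise / decode step runs here either way")

Because contextlib.nullcontext() satisfies the same protocol, the with-statement never has to branch on whether seamless was configured, which is the point of the "Fix issue with seamless context managers when seamless is not configured" commit.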

View File

@@ -19,6 +19,13 @@ from .baseinvocation import (
 )


+class SeamlessSettings(BaseModel):
+    axes: List[str] = Field(description="Axes('x' and 'y') to which apply seamless")
+    skipped_layers: int = Field(description="How much down layers skip when applying seamless")
+    skip_second_resnet: bool = Field(description="Skip or not second resnet in down blocks when applying seamless")
+    skip_conv2: bool = Field(description="Skip or not conv2 in down blocks when applying seamless")
+
+
 class ModelInfo(BaseModel):
     model_name: str = Field(description="Info to load submodel")
     base_model: BaseModelType = Field(description="Base model")
@@ -36,8 +43,8 @@ class UNetField(BaseModel):
     unet: ModelInfo = Field(description="Info to load unet submodel")
     scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
     loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
-    seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
     freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
+    seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model")


 class ClipField(BaseModel):
@@ -50,7 +57,7 @@ class ClipField(BaseModel):
 class VaeField(BaseModel):
     # TODO: better naming?
     vae: ModelInfo = Field(description="Info to load vae submodel")
-    seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
+    seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model")


 @invocation_output("unet_output")
@@ -451,6 +458,11 @@ class SeamlessModeInvocation(BaseInvocation):
     )
     seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
     seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
+    skipped_layers: int = InputField(default=0, input=Input.Any, description="How much model's down layers to skip")
+    skip_second_resnet: bool = InputField(
+        default=True, input=Input.Any, description="Skip or not second resnet in down layers"
+    )
+    skip_conv2: bool = InputField(default=True, input=Input.Any, description="Skip or not conv2 in down layers")

     def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
         # Conditionally append 'x' and 'y' based on seamless_x and seamless_y
@@ -465,9 +477,19 @@ class SeamlessModeInvocation(BaseInvocation):
             seamless_axes_list.append("y")

         if unet is not None:
-            unet.seamless_axes = seamless_axes_list
+            unet.seamless = SeamlessSettings(
+                axes=seamless_axes_list,
+                skipped_layers=self.skipped_layers,
+                skip_second_resnet=self.skip_second_resnet,
+                skip_conv2=self.skip_conv2,
+            )
         if vae is not None:
-            vae.seamless_axes = seamless_axes_list
+            vae.seamless = SeamlessSettings(
+                axes=seamless_axes_list,
+                skipped_layers=self.skipped_layers,
+                skip_second_resnet=self.skip_second_resnet,
+                skip_conv2=self.skip_conv2,
+            )

         return SeamlessModeOutput(unet=unet, vae=vae)
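
The seamless field added to UNetField and VaeField above defaults to None, so graphs without a Seamless node leave the models unpatched; when SeamlessModeInvocation runs, it attaches a plain data payload that serializes with the rest of the field. A small sketch of that round trip, assuming pydantic v2 (which the model_dump() calls elsewhere in this diff suggest); the Field descriptions are omitted and VaeFieldSketch is a hypothetical stand-in:

from typing import List, Optional

from pydantic import BaseModel


class SeamlessSettings(BaseModel):
    # Same shape as the class added above, without the Field(description=...) metadata.
    axes: List[str]
    skipped_layers: int
    skip_second_resnet: bool
    skip_conv2: bool


class VaeFieldSketch(BaseModel):
    # Stand-in for VaeField: seamless stays None unless a Seamless node sets it.
    seamless: Optional[SeamlessSettings] = None


vae_field = VaeFieldSketch()
assert vae_field.seamless is None  # default: downstream code takes the no-op path

vae_field.seamless = SeamlessSettings(
    axes=["x", "y"], skipped_layers=0, skip_second_resnet=True, skip_conv2=True
)
payload = vae_field.model_dump()  # plain dict, serializable with the rest of the graph
assert VaeFieldSketch.model_validate(payload).seamless == vae_field.seamless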

View File

@@ -13,7 +13,6 @@ from safetensors.torch import load_file
 from transformers import CLIPTextModel, CLIPTokenizer

 from invokeai.app.shared.models import FreeUConfig
-from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init

 from .models.lora import LoRAModel
@@ -212,12 +211,8 @@ class ModelPatcher:
                 for i in range(ti_embedding.shape[0]):
                     new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))

-            # Modify text_encoder.
-            # resize_token_embeddings(...) constructs a new torch.nn.Embedding internally. Initializing the weights of
-            # this embedding is slow and unnecessary, so we wrap this step in skip_torch_weight_init() to save some
-            # time.
-            with skip_torch_weight_init():
-                text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
+            # modify text_encoder
+            text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
             model_embeddings = text_encoder.get_input_embeddings()

             for ti_name, ti in ti_list:
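
For context on the hunk above: the removed comment explains that resize_token_embeddings() builds a new torch.nn.Embedding internally, and that skip_torch_weight_init() was wrapped around it to avoid the slow, unnecessary weight initialization. A rough, illustrative sketch of how such a helper can work (not the repository's exact implementation): temporarily swap reset_parameters on common layer types for a no-op, then restore it.

from contextlib import contextmanager

import torch


def _no_op(*args, **kwargs):
    pass


@contextmanager
def skip_weight_init_sketch():
    # Patch the classes whose construction cost we want to avoid.
    patched = [torch.nn.Linear, torch.nn.Embedding]
    saved = [cls.reset_parameters for cls in patched]
    try:
        for cls in patched:
            cls.reset_parameters = _no_op
        yield
    finally:
        # Always restore the original initializers.
        for cls, fn in zip(patched, saved):
            cls.reset_parameters = fn


with skip_weight_init_sketch():
    # Allocated, but weights are left uninitialized (they will be overwritten anyway).
    emb = torch.nn.Embedding(50000, 768)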

View File

@@ -25,71 +25,55 @@ def _conv_forward_asymmetric(self, input, weight, bias):
 @contextmanager
-def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
+def set_seamless(
+    model: Union[UNet2DConditionModel, AutoencoderKL],
+    axes: List[str],
+    skipped_layers: int,
+    skip_second_resnet: bool,
+    skip_conv2: bool,
+):
     try:
         to_restore = []

         for m_name, m in model.named_modules():
-            if isinstance(model, UNet2DConditionModel):
-                if ".attentions." in m_name:
-                    continue
-
-                if ".resnets." in m_name:
-                    if ".conv2" in m_name:
-                        continue
-                    if ".conv_shortcut" in m_name:
-                        continue
-
-            """
-            if isinstance(model, UNet2DConditionModel):
-                if False and ".upsamplers." in m_name:
-                    continue
-
-                if False and ".downsamplers." in m_name:
-                    continue
-
-                if True and ".resnets." in m_name:
-                    if True and ".conv1" in m_name:
-                        if False and "down_blocks" in m_name:
-                            continue
-                        if False and "mid_block" in m_name:
-                            continue
-                        if False and "up_blocks" in m_name:
-                            continue
-
-                    if True and ".conv2" in m_name:
-                        continue
-
-                    if True and ".conv_shortcut" in m_name:
-                        continue
-
-                if True and ".attentions." in m_name:
-                    continue
-
-                if False and m_name in ["conv_in", "conv_out"]:
-                    continue
-            """
-
-            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
-                m.asymmetric_padding_mode = {}
-                m.asymmetric_padding = {}
-                m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
-                m.asymmetric_padding["x"] = (
-                    m._reversed_padding_repeated_twice[0],
-                    m._reversed_padding_repeated_twice[1],
-                    0,
-                    0,
-                )
-                m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
-                m.asymmetric_padding["y"] = (
-                    0,
-                    0,
-                    m._reversed_padding_repeated_twice[2],
-                    m._reversed_padding_repeated_twice[3],
-                )
-
-                to_restore.append((m, m._conv_forward))
-                m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
+            if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+                continue
+
+            if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name:
+                # down_blocks.1.resnets.1.conv1
+                _, block_num, _, resnet_num, submodule_name = m_name.split(".")
+                block_num = int(block_num)
+                resnet_num = int(resnet_num)
+
+                # if block_num >= seamless_down_blocks:
+                if block_num >= len(model.down_blocks) - skipped_layers:
+                    continue
+
+                if resnet_num > 0 and skip_second_resnet:
+                    continue
+
+                if submodule_name == "conv2" and skip_conv2:
+                    continue
+
+            m.asymmetric_padding_mode = {}
+            m.asymmetric_padding = {}
+            m.asymmetric_padding_mode["x"] = "circular" if ("x" in axes) else "constant"
+            m.asymmetric_padding["x"] = (
+                m._reversed_padding_repeated_twice[0],
+                m._reversed_padding_repeated_twice[1],
+                0,
+                0,
+            )
+            m.asymmetric_padding_mode["y"] = "circular" if ("y" in axes) else "constant"
+            m.asymmetric_padding["y"] = (
+                0,
+                0,
+                m._reversed_padding_repeated_twice[2],
+                m._reversed_padding_repeated_twice[3],
+            )
+
+            to_restore.append((m, m._conv_forward))
+            m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)

         yield
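
What the patch above does to each matching Conv2d: it stores per-axis padding plus a padding mode ("circular" on the seamless axes), and swaps in _conv_forward_asymmetric so the convolution wraps around at the tile edges. A minimal, self-contained sketch of that padding trick (illustrative only, not the repository's _conv_forward_asymmetric): pad the input circularly along the chosen axes, then convolve with zero padding.

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_forward_seamless(conv: nn.Conv2d, x: torch.Tensor, axes: List[str]) -> torch.Tensor:
    # _reversed_padding_repeated_twice is (left, right, top, bottom), matching F.pad's order.
    px = conv._reversed_padding_repeated_twice[0:2]
    py = conv._reversed_padding_repeated_twice[2:4]
    x = F.pad(x, (*px, 0, 0), mode="circular" if "x" in axes else "constant")
    x = F.pad(x, (0, 0, *py), mode="circular" if "y" in axes else "constant")
    # All padding is done manually above, so the convolution itself uses padding=0.
    return F.conv2d(x, conv.weight, conv.bias, conv.stride, 0, conv.dilation, conv.groups)


conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
latent = torch.randn(1, 3, 16, 16)
out = conv_forward_seamless(conv, latent, axes=["x"])
print(out.shape)  # torch.Size([1, 8, 16, 16]); columns now wrap around seamlessly

Cells on the left edge see their neighbors from the right edge, which is what makes the decoded image tileable along that axis.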

View File

@@ -1,3 +0,0 @@
-<svg width="44" height="44" viewBox="0 0 44 44" fill="none" xmlns="http://www.w3.org/2000/svg">
-<path d="M29.1951 10.6667H42V2H2V10.6667H14.8049L29.1951 33.3333H42V42H2V33.3333H14.8049" stroke="#E6FD13" stroke-width="2.8"/>
-</svg>

Before  |  Size: 231 B

Binary file not shown.

After  |  Size: 116 KiB

View File

@@ -8,8 +8,8 @@
     <meta http-equiv="Pragma" content="no-cache">
     <meta http-equiv="Expires" content="0">
     <title>InvokeAI - A Stable Diffusion Toolkit</title>
-    <link rel="mask-icon" type="icon" href="favicon-outline.svg" color="#E6FD13" sizes="any" />
-    <link rel="icon" type="icon" href="favicon-key.svg" />
+    <link rel="mask-icon" href="/invoke-key-ylw-sm.svg" color="#E6FD13" sizes="any" />
+    <link rel="icon" href="/invoke-key-char-on-ylw.svg" />
     <style>
       html,
       body {

View File

Before  |  Size: 272 B

After  |  Size: 272 B

Binary file not shown.

After  |  Size: 336 KiB

Binary file not shown.

After  |  Size: 43 KiB

View File

@@ -45,7 +45,6 @@ export const InvControl = memo(
         orientation={orientation}
         isDisabled={isDisabled}
         {...formControlProps}
-        {...ctx.controlProps}
       >
         <Flex className="invcontrol-label-wrapper">
           {label && (

View File

@@ -1,10 +1,9 @@
-import type { FormControlProps, FormLabelProps } from '@chakra-ui/react';
+import type { FormLabelProps } from '@chakra-ui/react';
 import type { PropsWithChildren } from 'react';
 import { createContext, memo } from 'react';

 export type InvControlGroupProps = {
   labelProps?: FormLabelProps;
-  controlProps?: FormControlProps;
   isDisabled?: boolean;
   orientation?: 'horizontal' | 'vertical';
 };

View File

@@ -1,6 +1,6 @@
 import { Box, Flex, Image } from '@chakra-ui/react';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import InvokeLogoSVG from 'assets/images/invoke-key-wht-lrg.svg';
+import InvokeAILogoImage from 'assets/images/logo.png';
 import IAIDroppable from 'common/components/IAIDroppable';
 import { InvText } from 'common/components/InvText/wrapper';
 import { InvTooltip } from 'common/components/InvTooltip/InvTooltip';
@@ -101,10 +101,10 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
         alignItems="center"
       >
         <Image
-          src={InvokeLogoSVG}
+          src={InvokeAILogoImage}
           alt="invoke-ai-logo"
-          opacity={0.7}
-          mixBlendMode="overlay"
+          opacity={0.4}
+          filter="grayscale(1)"
           mt={-6}
           w={16}
           h={16}

View File

@@ -4,10 +4,10 @@ import {
   InvCardBody,
   InvCardHeader,
 } from 'common/components/InvCard/wrapper';
-import { InvLabel } from 'common/components/InvControl/InvLabel';
 import { InvIconButton } from 'common/components/InvIconButton/InvIconButton';
 import { InvNumberInput } from 'common/components/InvNumberInput/InvNumberInput';
 import { InvSlider } from 'common/components/InvSlider/InvSlider';
+import { InvText } from 'common/components/InvText/wrapper';
 import type { LoRA } from 'features/lora/store/loraSlice';
 import { loraRemoved, loraWeightChanged } from 'features/lora/store/loraSlice';
 import { memo, useCallback } from 'react';
@@ -35,9 +35,9 @@ export const LoRACard = memo((props: LoRACardProps) => {
   return (
     <InvCard variant="lora">
       <InvCardHeader>
-        <InvLabel noOfLines={1} wordBreak="break-all">
+        <InvText noOfLines={1} wordBreak="break-all">
           {lora.model_name}
-        </InvLabel>
+        </InvText>
         <InvIconButton
           aria-label="Remove LoRA"
           variant="ghost"

View File

@@ -1,6 +1,4 @@
-import type { ChakraProps } from '@chakra-ui/react';
 import { Flex } from '@chakra-ui/react';
-import { InvControlGroup } from 'common/components/InvControl/InvControlGroup';
 import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
 import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
 import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@@ -13,8 +11,6 @@ type Props = {
   nodeId: string;
 };

-const props: ChakraProps = { w: 'unset' };
-
 const InvocationNodeFooter = ({ nodeId }: Props) => {
   const hasImageOutput = useHasImageOutput(nodeId);
   const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
@@ -24,16 +20,13 @@ const InvocationNodeFooter = ({ nodeId }: Props) => {
       layerStyle="nodeFooter"
       w="full"
       borderBottomRadius="base"
-      gap={4}
       px={2}
       py={0}
       h={8}
       justifyContent="space-between"
     >
-      <InvControlGroup controlProps={props} labelProps={props}>
-        {isCacheEnabled && <UseCacheCheckbox nodeId={nodeId} />}
-        {hasImageOutput && <SaveToGalleryCheckbox nodeId={nodeId} />}
-      </InvControlGroup>
+      {isCacheEnabled && <UseCacheCheckbox nodeId={nodeId} />}
+      {hasImageOutput && <SaveToGalleryCheckbox nodeId={nodeId} />}
     </Flex>
   );
 };

View File

@@ -37,10 +37,11 @@ import { SettingsLanguageSelect } from './SettingsLanguageSelect';
 import { SettingsLogLevelSelect } from './SettingsLogLevelSelect';

 type ConfigOptions = {
-  shouldShowDeveloperSettings?: boolean;
-  shouldShowResetWebUiText?: boolean;
-  shouldShowClearIntermediates?: boolean;
-  shouldShowLocalizationToggle?: boolean;
+  shouldShowDeveloperSettings: boolean;
+  shouldShowResetWebUiText: boolean;
+  shouldShowAdvancedOptionsSettings: boolean;
+  shouldShowClearIntermediates: boolean;
+  shouldShowLocalizationToggle: boolean;
 };

 type SettingsModalProps = {
@@ -83,7 +84,7 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
     hasPendingItems,
     intermediatesCount,
     isLoading: isLoadingClearIntermediates,
-  } = useClearIntermediates(shouldShowClearIntermediates);
+  } = useClearIntermediates();

   const {
     isOpen: isSettingsModalOpen,

View File

@@ -17,9 +17,7 @@ export type UseClearIntermediatesReturn = {
   hasPendingItems: boolean;
 };

-export const useClearIntermediates = (
-  shouldShowClearIntermediates: boolean
-): UseClearIntermediatesReturn => {
+export const useClearIntermediates = (): UseClearIntermediatesReturn => {
   const { t } = useTranslation();
   const dispatch = useAppDispatch();
@@ -27,7 +25,6 @@ export const useClearIntermediates = (
     undefined,
     {
       refetchOnMountOrArgChange: true,
-      skip: !shouldShowClearIntermediates,
     }
   );

View File

@@ -1 +1 @@
-__version__ = "3.6.0rc4"
+__version__ = "3.6.0rc3"