Compare commits

...

5 Commits

Author           SHA1        Message / Date

psychedelicious  a8957aa50d  chore: bump version to v5.0.2
                             2024-10-02 09:35:07 +10:00

Ryan Dick        807f458f13  Move FLUX_LORA_TRANSFORMER_PREFIX and FLUX_LORA_CLIP_PREFIX to a shared location.
                             2024-10-01 10:22:11 -04:00

Ryan Dick        68dbe45315  Fix regression with FLUX diffusers LoRA models where lora keys were not given the expected prefix.
                             2024-10-01 10:22:11 -04:00

psychedelicious  bd3d1dcdf9  feat(ui): hide model settings if there isn't any content
                             For example, CLIP Vision models have no settings.
                             2024-09-30 22:10:14 -04:00

psychedelicious  386c01ede1  feat(ui): show CLIP Vision models in model manager UI
                             Not sure why they were hidden but it makes it hard to delete them if they are borked for some reason (have to go thru API docs page or do DB surgery).
                             2024-09-30 22:10:14 -04:00
15 changed files with 82 additions and 36 deletions

View File

@@ -30,7 +30,7 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
-from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
@@ -209,7 +209,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
LoRAPatcher.apply_lora_patches(
model=transformer,
patches=self._lora_iterator(context),
-prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
+prefix=FLUX_LORA_TRANSFORMER_PREFIX,
cached_weights=cached_weights,
)
)
@@ -220,7 +220,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
LoRAPatcher.apply_lora_sidecar_patches(
model=transformer,
patches=self._lora_iterator(context),
-prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
+prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
)
)

View File

@@ -10,7 +10,7 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
-from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_CLIP_PREFIX
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
@@ -101,7 +101,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
LoRAPatcher.apply_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context),
-prefix=FLUX_KOHYA_CLIP_PREFIX,
+prefix=FLUX_LORA_CLIP_PREFIX,
cached_weights=cached_weights,
)
)

View File

@@ -2,6 +2,7 @@ from typing import Dict
import torch
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
@@ -189,7 +190,9 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
# Assert that all keys were processed.
assert len(grouped_state_dict) == 0
-return LoRAModelRaw(layers=layers)
+layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
+return LoRAModelRaw(layers=layers_with_prefix)
def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
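
The hunk above is the actual regression fix: the Diffusers-format FLUX LoRA converter was returning bare layer keys, while FluxDenoiseInvocation (first file in this compare) passes prefix=FLUX_LORA_TRANSFORMER_PREFIX to LoRAPatcher.apply_lora_patches, so the unprefixed layers were presumably never matched at patch time. A minimal sketch of the added prefixing step, using a hypothetical converted layer key rather than one from a real state dict:

# Value copied from the new flux_lora_constants module introduced in this compare.
FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"

def add_transformer_prefix(layers: dict) -> dict:
    # Namespace every converted layer key under the shared transformer prefix so
    # prefix-based patching can find it (mirrors the added dict comprehension above).
    return {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}

layers = {"single_blocks.0.linear1": object()}  # hypothetical layer key
prefixed = add_transformer_prefix(layers)
# Same property the updated diffusers conversion test asserts later in this compare.
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in prefixed)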

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, TypeVar
import torch
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
@@ -23,11 +24,6 @@ FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"
-# Prefixes used to distinguish between transformer and CLIP text encoder keys in the InvokeAI LoRA format.
-FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer-"
-FLUX_KOHYA_CLIP_PREFIX = "lora_clip-"
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
@@ -67,9 +63,9 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
# Create LoRA layers.
layers: dict[str, AnyLoRALayer] = {}
for layer_key, layer_state_dict in transformer_grouped_sd.items():
-layers[FLUX_KOHYA_TRANFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
for layer_key, layer_state_dict in clip_grouped_sd.items():
-layers[FLUX_KOHYA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
# Create and return the LoRAModelRaw.
return LoRAModelRaw(layers=layers)
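
With the local FLUX_KOHYA_* constants deleted, the Kohya converter now writes its layers under the same shared prefixes as the Diffusers converter in the previous file, so both source formats land in a single key namespace. A minimal sketch of that routing, based on the prefix replacements the updated Kohya test at the end of this compare performs (the raw keys below are hypothetical examples):

FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"  # values from flux_lora_constants
FLUX_LORA_CLIP_PREFIX = "lora_clip-"

def route_kohya_key(raw_key: str) -> str:
    # Kohya FLUX transformer keys start with "lora_unet_"; CLIP text encoder keys
    # start with "lora_te1_". Map each family onto the shared InvokeAI prefix.
    if raw_key.startswith("lora_unet_"):
        return raw_key.replace("lora_unet_", FLUX_LORA_TRANSFORMER_PREFIX, 1)
    if raw_key.startswith("lora_te1_"):
        return raw_key.replace("lora_te1_", FLUX_LORA_CLIP_PREFIX, 1)
    raise ValueError(f"Unrecognized Kohya FLUX LoRA key: {raw_key}")

print(route_kohya_key("lora_unet_double_blocks_0_img_attn_qkv"))        # hypothetical key
print(route_kohya_key("lora_te1_text_model_encoder_layers_0_mlp_fc1"))  # hypothetical key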

View File

@@ -0,0 +1,3 @@
+# Prefixes used to distinguish between transformer and CLIP text encoder keys in the FLUX InvokeAI LoRA format.
+FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"
+FLUX_LORA_CLIP_PREFIX = "lora_clip-"
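
Keeping both prefixes in this one small constants module means the invocation code at the top of this compare no longer imports them from the Kohya conversion module, and it makes their purpose explicit: one LoRA model can hold transformer and CLIP layers side by side, with each call site selecting its own slice by prefix. A minimal sketch of that selection, assuming (this is not the actual LoRAPatcher implementation) that patch application keeps only the keys matching the given prefix:

FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"
FLUX_LORA_CLIP_PREFIX = "lora_clip-"

def select_layers(layers: dict, prefix: str) -> dict:
    # Keep only the layers belonging to the sub-model identified by `prefix`,
    # stripping the prefix so the remaining key names the target layer.
    return {k[len(prefix):]: v for k, v in layers.items() if k.startswith(prefix)}

layers = {  # hypothetical prefixed keys, as produced by the two converters above
    "lora_transformer-double_blocks_0_img_attn_qkv": "transformer LoRA layer",
    "lora_clip-text_model_encoder_layers_0_mlp_fc1": "CLIP LoRA layer",
}
assert select_layers(layers, FLUX_LORA_TRANSFORMER_PREFIX) == {"double_blocks_0_img_attn_qkv": "transformer LoRA layer"}
assert select_layers(layers, FLUX_LORA_CLIP_PREFIX) == {"text_model_encoder_layers_0_mlp_fc1": "CLIP LoRA layer"}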

View File

@@ -3,7 +3,7 @@ import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { ModelType } from 'services/api/types';
-export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
+export type FilterableModelType = Exclude<ModelType, 'onnx'> | 'refiner';
type ModelManagerState = {
_version: 1;

View File

@@ -10,6 +10,7 @@ import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
useCLIPEmbedModels,
+useCLIPVisionModels,
useControlNetModels,
useEmbeddingModels,
useIPAdapterModels,
@@ -73,6 +74,12 @@ const ModelList = () => {
[ipAdapterModels, searchTerm, filteredModelType]
);
+const [clipVisionModels, { isLoading: isLoadingCLIPVisionModels }] = useCLIPVisionModels();
+const filteredCLIPVisionModels = useMemo(
+() => modelsFilter(clipVisionModels, searchTerm, filteredModelType),
+[clipVisionModels, searchTerm, filteredModelType]
+);
const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels();
const filteredVAEModels = useMemo(
() => modelsFilter(vaeModels, searchTerm, filteredModelType),
@@ -107,6 +114,7 @@ const ModelList = () => {
filteredControlNetModels.length +
filteredT2IAdapterModels.length +
filteredIPAdapterModels.length +
+filteredCLIPVisionModels.length +
filteredVAEModels.length +
filteredSpandrelImageToImageModels.length +
t5EncoderModels.length +
@@ -116,6 +124,7 @@ const ModelList = () => {
filteredControlNetModels.length,
filteredEmbeddingModels.length,
filteredIPAdapterModels.length,
+filteredCLIPVisionModels.length,
filteredLoRAModels.length,
filteredMainModels.length,
filteredRefinerModels.length,
@@ -171,6 +180,11 @@ const ModelList = () => {
{!isLoadingIPAdapterModels && filteredIPAdapterModels.length > 0 && (
<ModelListWrapper title={t('common.ipAdapter')} modelList={filteredIPAdapterModels} key="ip-adapters" />
)}
+{/* CLIP Vision List */}
+{isLoadingCLIPVisionModels && <FetchingModelsLoader loadingMessage="Loading CLIP Vision Models..." />}
+{!isLoadingCLIPVisionModels && filteredCLIPVisionModels.length > 0 && (
+<ModelListWrapper title="CLIP Vision" modelList={filteredCLIPVisionModels} key="clip-vision" />
+)}
{/* T2I Adapters List */}
{isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />}
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (

View File

@@ -22,6 +22,7 @@ export const ModelTypeFilter = memo(() => {
t5_encoder: t('modelManager.t5Encoder'),
clip_embed: t('modelManager.clipEmbed'),
ip_adapter: t('common.ipAdapter'),
+clip_vision: 'CLIP Vision',
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
}),
[t]

View File

@@ -120,14 +120,18 @@ export const ModelEdit = memo(({ modelConfig }: Props) => {
<Textarea {...form.register('description')} minH={32} />
</FormControl>
</Flex>
-<Heading as="h3" fontSize="md" mt="4">
-{t('modelManager.modelSettings')}
-</Heading>
+{modelConfig.type !== 'clip_vision' && (
+<Heading as="h3" fontSize="md" mt="4">
+{t('modelManager.modelSettings')}
+</Heading>
+)}
<SimpleGrid columns={2} gap={4}>
-<FormControl flexDir="column" alignItems="flex-start" gap={1}>
-<FormLabel>{t('modelManager.baseModel')}</FormLabel>
-<BaseModelSelect control={form.control} />
-</FormControl>
+{modelConfig.type !== 'clip_vision' && (
+<FormControl flexDir="column" alignItems="flex-start" gap={1}>
+<FormLabel>{t('modelManager.baseModel')}</FormLabel>
+<BaseModelSelect control={form.control} />
+</FormControl>
+)}
{modelConfig.type === 'main' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>

View File

@@ -4,7 +4,7 @@ import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel
import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
import { TriggerPhrases } from 'features/modelManagerV2/subpanels/ModelPanel/TriggerPhrases';
-import { memo } from 'react';
+import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
@@ -17,6 +17,20 @@ type Props = {
export const ModelView = memo(({ modelConfig }: Props) => {
const { t } = useTranslation();
+const withSettings = useMemo(() => {
+if (modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner') {
+return true;
+}
+if (modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') {
+return true;
+}
+if (modelConfig.type === 'main' || modelConfig.type === 'lora') {
+return true;
+}
+return false;
+}, [modelConfig.base, modelConfig.type]);
return (
<Flex flexDir="column" gap={4}>
<ModelHeader modelConfig={modelConfig}>
@@ -50,15 +64,19 @@ export const ModelView = memo(({ modelConfig }: Props) => {
)}
</SimpleGrid>
</Box>
-<Box layerStyle="second" borderRadius="base" p={4}>
-{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
-<MainModelDefaultSettings modelConfig={modelConfig} />
-)}
-{(modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') && (
-<ControlNetOrT2IAdapterDefaultSettings modelConfig={modelConfig} />
-)}
-{(modelConfig.type === 'main' || modelConfig.type === 'lora') && <TriggerPhrases modelConfig={modelConfig} />}
-</Box>
+{withSettings && (
+<Box layerStyle="second" borderRadius="base" p={4}>
+{modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
+<MainModelDefaultSettings modelConfig={modelConfig} />
+)}
+{(modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') && (
+<ControlNetOrT2IAdapterDefaultSettings modelConfig={modelConfig} />
+)}
+{(modelConfig.type === 'main' || modelConfig.type === 'lora') && (
+<TriggerPhrases modelConfig={modelConfig} />
+)}
+</Box>
+)}
</Flex>
</Flex>
);

View File

@@ -10,6 +10,7 @@ import {
import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
+isCLIPVisionModelConfig,
isControlNetModelConfig,
isControlNetOrT2IAdapterModelConfig,
isFluxMainModelModelConfig,
@@ -58,6 +59,7 @@ export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
export const useVAEModels = buildModelsHook(isVAEModelConfig);
export const useFluxVAEModels = buildModelsHook(isFluxVAEModelConfig);
+export const useCLIPVisionModels = buildModelsHook(isCLIPVisionModelConfig);
// const buildModelsSelector =
// <T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T): Selector<RootState, T[]> =>

View File

@@ -99,6 +99,10 @@ export const isIPAdapterModelConfig = (config: AnyModelConfig): config is IPAdap
return config.type === 'ip_adapter';
};
+export const isCLIPVisionModelConfig = (config: AnyModelConfig): config is CLIPVisionDiffusersConfig => {
+return config.type === 'clip_vision';
+};
export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAdapterModelConfig => {
return config.type === 't2i_adapter';
};

View File

@@ -1 +1 @@
-__version__ = "5.0.1"
+__version__ = "5.0.2"

View File

@@ -5,6 +5,7 @@ from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils impo
is_state_dict_likely_in_flux_diffusers_format,
lora_model_from_flux_diffusers_state_dict,
)
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
@@ -50,6 +51,7 @@ def test_lora_model_from_flux_diffusers_state_dict():
concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
assert len(model.layers) == len(expected_lora_layers)
+assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())
def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():

View File

@@ -5,12 +5,11 @@ import torch
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
-FLUX_KOHYA_CLIP_PREFIX,
-FLUX_KOHYA_TRANFORMER_PREFIX,
_convert_flux_transformer_kohya_state_dict_to_invoke_format,
is_state_dict_likely_in_flux_kohya_format,
lora_model_from_flux_kohya_state_dict,
)
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
state_dict_keys as flux_diffusers_state_dict_keys,
)
@@ -95,8 +94,8 @@ def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
expected_layer_keys: set[str] = set()
for k in sd_keys:
# Replace prefixes.
-k = k.replace("lora_unet_", FLUX_KOHYA_TRANFORMER_PREFIX)
-k = k.replace("lora_te1_", FLUX_KOHYA_CLIP_PREFIX)
+k = k.replace("lora_unet_", FLUX_LORA_TRANSFORMER_PREFIX)
+k = k.replace("lora_te1_", FLUX_LORA_CLIP_PREFIX)
# Remove suffixes.
k = k.replace(".lora_up.weight", "")
k = k.replace(".lora_down.weight", "")