Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-17 17:37:55 -05:00)

Compare commits (5 commits):

- a8957aa50d
- 807f458f13
- 68dbe45315
- bd3d1dcdf9
- 386c01ede1
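Taken together, the five commits make two independent sets of changes. On the backend, the FLUX LoRA key prefixes move out of the Kohya conversion module into a new shared invokeai/backend/lora/conversions/flux_lora_constants.py, are renamed from FLUX_KOHYA_* to the format-agnostic FLUX_LORA_*, and are now also applied by the Diffusers-format converter, so LoRAs from either source format yield identically prefixed layer keys for the patcher. On the frontend, CLIP Vision becomes a first-class model type in the Model Manager (type guard, models hook, filter entry, list section), and the edit/view panels stop rendering settings controls that do not apply. The version is bumped from 5.0.1 to 5.0.2.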
@@ -30,7 +30,7 @@ from invokeai.backend.flux.sampling_utils import (
     pack,
     unpack,
 )
-from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
@@ -209,7 +209,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             LoRAPatcher.apply_lora_patches(
                 model=transformer,
                 patches=self._lora_iterator(context),
-                prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
+                prefix=FLUX_LORA_TRANSFORMER_PREFIX,
                 cached_weights=cached_weights,
             )
         )
@@ -220,7 +220,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             LoRAPatcher.apply_lora_sidecar_patches(
                 model=transformer,
                 patches=self._lora_iterator(context),
-                prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
+                prefix=FLUX_LORA_TRANSFORMER_PREFIX,
                 dtype=inference_dtype,
             )
         )
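The prefix argument is what routes a combined LoRA state dict to the right model: keys that do not start with the given prefix are skipped, so the transformer only ever sees lora_transformer- layers. A minimal sketch of that routing idea (not InvokeAI's actual LoRAPatcher, which applies low-rank patches and handles weight caching and restoration; the keys below are hypothetical):

```python
import torch

def route_patches(patches: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]:
    # Keep only layers addressed to this model, stripping the routing prefix.
    return {k[len(prefix):]: v for k, v in patches.items() if k.startswith(prefix)}

patches = {
    "lora_transformer-single_blocks_0_linear1": torch.zeros(1),  # hypothetical key
    "lora_clip-text_model_encoder_layers_0_mlp_fc1": torch.zeros(1),  # hypothetical key
}
assert list(route_patches(patches, "lora_transformer-")) == ["single_blocks_0_linear1"]
```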
@@ -10,7 +10,7 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import FluxConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.modules.conditioner import HFEncoder
-from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_CLIP_PREFIX
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
@@ -101,7 +101,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
             LoRAPatcher.apply_lora_patches(
                 model=clip_text_encoder,
                 patches=self._clip_lora_iterator(context),
-                prefix=FLUX_KOHYA_CLIP_PREFIX,
+                prefix=FLUX_LORA_CLIP_PREFIX,
                 cached_weights=cached_weights,
             )
         )
@@ -2,6 +2,7 @@ from typing import Dict

 import torch

+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.lora.layers.lora_layer import LoRALayer
@@ -189,7 +190,9 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
     # Assert that all keys were processed.
     assert len(grouped_state_dict) == 0

-    return LoRAModelRaw(layers=layers)
+    layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
+
+    return LoRAModelRaw(layers=layers_with_prefix)


 def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
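This is the substantive behavior change for Diffusers-format FLUX LoRAs: every converted layer key now carries the transformer prefix, matching what the denoise invocation's patcher expects. The comprehension in isolation, with toy values (layer keys here are hypothetical; real ones come from the converted state dict):

```python
FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"  # from flux_lora_constants.py

layers = {"single_blocks_0_linear1": "layer_a", "double_blocks_3_img_attn_qkv": "layer_b"}
layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
assert "lora_transformer-single_blocks_0_linear1" in layers_with_prefix
```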
@@ -3,6 +3,7 @@ from typing import Any, Dict, TypeVar

 import torch

+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
@@ -23,11 +24,6 @@ FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
 FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"


-# Prefixes used to distinguish between transformer and CLIP text encoder keys in the InvokeAI LoRA format.
-FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer-"
-FLUX_KOHYA_CLIP_PREFIX = "lora_clip-"
-
-
 def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
     """Checks if the provided state dict is likely in the Kohya FLUX LoRA format.

@@ -67,9 +63,9 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
     # Create LoRA layers.
     layers: dict[str, AnyLoRALayer] = {}
     for layer_key, layer_state_dict in transformer_grouped_sd.items():
-        layers[FLUX_KOHYA_TRANFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+        layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
     for layer_key, layer_state_dict in clip_grouped_sd.items():
-        layers[FLUX_KOHYA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+        layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)

     # Create and return the LoRAModelRaw.
     return LoRAModelRaw(layers=layers)
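For context, transformer_grouped_sd and clip_grouped_sd above are nested dicts keyed by layer, each holding that layer's tensors (lora_up.weight, lora_down.weight, alpha, ...). A rough sketch of that grouping step, assuming the Kohya "<layer>.<param>" key shape seen in the test diff below (the real helper in this module may differ in details):

```python
from typing import Dict

import torch

def group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
    # e.g. "lora_unet_double_blocks_0_img_attn_proj.lora_up.weight" splits into
    # layer key "lora_unet_double_blocks_0_img_attn_proj" and param "lora_up.weight".
    grouped: dict[str, dict[str, torch.Tensor]] = {}
    for key, tensor in state_dict.items():
        layer_key, _, param_name = key.partition(".")
        grouped.setdefault(layer_key, {})[param_name] = tensor
    return grouped
```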
invokeai/backend/lora/conversions/flux_lora_constants.py (new file, 3 lines)
@@ -0,0 +1,3 @@
+# Prefixes used to distinguish between transformer and CLIP text encoder keys in the FLUX InvokeAI LoRA format.
+FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"
+FLUX_LORA_CLIP_PREFIX = "lora_clip-"
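Centralizing the two prefixes in one module means the producers (both conversion utilities) and consumers (the patch-applying invocations) share literally the same strings, so the transformer and CLIP halves of a combined LoRA can never collide. A trivial sanity check, assuming InvokeAI is importable:

```python
from invokeai.backend.lora.conversions.flux_lora_constants import (
    FLUX_LORA_CLIP_PREFIX,
    FLUX_LORA_TRANSFORMER_PREFIX,
)

# Distinct, and neither is a prefix of the other, so prefix routing is unambiguous.
assert FLUX_LORA_TRANSFORMER_PREFIX != FLUX_LORA_CLIP_PREFIX
assert not FLUX_LORA_TRANSFORMER_PREFIX.startswith(FLUX_LORA_CLIP_PREFIX)
assert not FLUX_LORA_CLIP_PREFIX.startswith(FLUX_LORA_TRANSFORMER_PREFIX)
```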
@@ -3,7 +3,7 @@ import { createSelector, createSlice } from '@reduxjs/toolkit';
 import type { PersistConfig, RootState } from 'app/store/store';
 import type { ModelType } from 'services/api/types';

-export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
+export type FilterableModelType = Exclude<ModelType, 'onnx'> | 'refiner';

 type ModelManagerState = {
   _version: 1;
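Removing 'clip_vision' from the Exclude list makes it a valid FilterableModelType, the prerequisite for the filter entry and model-list section added below.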
@@ -10,6 +10,7 @@ import { memo, useMemo } from 'react';
 import { useTranslation } from 'react-i18next';
 import {
   useCLIPEmbedModels,
+  useCLIPVisionModels,
   useControlNetModels,
   useEmbeddingModels,
   useIPAdapterModels,
@@ -73,6 +74,12 @@ const ModelList = () => {
     [ipAdapterModels, searchTerm, filteredModelType]
   );

+  const [clipVisionModels, { isLoading: isLoadingCLIPVisionModels }] = useCLIPVisionModels();
+  const filteredCLIPVisionModels = useMemo(
+    () => modelsFilter(clipVisionModels, searchTerm, filteredModelType),
+    [clipVisionModels, searchTerm, filteredModelType]
+  );
+
   const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels();
   const filteredVAEModels = useMemo(
     () => modelsFilter(vaeModels, searchTerm, filteredModelType),
@@ -107,6 +114,7 @@ const ModelList = () => {
       filteredControlNetModels.length +
       filteredT2IAdapterModels.length +
       filteredIPAdapterModels.length +
+      filteredCLIPVisionModels.length +
       filteredVAEModels.length +
       filteredSpandrelImageToImageModels.length +
       t5EncoderModels.length +
@@ -116,6 +124,7 @@ const ModelList = () => {
     filteredControlNetModels.length,
     filteredEmbeddingModels.length,
     filteredIPAdapterModels.length,
+    filteredCLIPVisionModels.length,
     filteredLoRAModels.length,
     filteredMainModels.length,
     filteredRefinerModels.length,
@@ -171,6 +180,11 @@ const ModelList = () => {
       {!isLoadingIPAdapterModels && filteredIPAdapterModels.length > 0 && (
         <ModelListWrapper title={t('common.ipAdapter')} modelList={filteredIPAdapterModels} key="ip-adapters" />
       )}
+      {/* CLIP Vision List */}
+      {isLoadingCLIPVisionModels && <FetchingModelsLoader loadingMessage="Loading CLIP Vision Models..." />}
+      {!isLoadingCLIPVisionModels && filteredCLIPVisionModels.length > 0 && (
+        <ModelListWrapper title="CLIP Vision" modelList={filteredCLIPVisionModels} key="clip-vision" />
+      )}
       {/* T2I Adapters List */}
       {isLoadingT2IAdapterModels && <FetchingModelsLoader loadingMessage="Loading T2I Adapters..." />}
       {!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
@@ -22,6 +22,7 @@ export const ModelTypeFilter = memo(() => {
       t5_encoder: t('modelManager.t5Encoder'),
       clip_embed: t('modelManager.clipEmbed'),
       ip_adapter: t('common.ipAdapter'),
+      clip_vision: 'CLIP Vision',
       spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
     }),
     [t]
@@ -120,14 +120,18 @@ export const ModelEdit = memo(({ modelConfig }: Props) => {
             <Textarea {...form.register('description')} minH={32} />
           </FormControl>
         </Flex>
-        <Heading as="h3" fontSize="md" mt="4">
-          {t('modelManager.modelSettings')}
-        </Heading>
+        {modelConfig.type !== 'clip_vision' && (
+          <Heading as="h3" fontSize="md" mt="4">
+            {t('modelManager.modelSettings')}
+          </Heading>
+        )}
         <SimpleGrid columns={2} gap={4}>
-          <FormControl flexDir="column" alignItems="flex-start" gap={1}>
-            <FormLabel>{t('modelManager.baseModel')}</FormLabel>
-            <BaseModelSelect control={form.control} />
-          </FormControl>
+          {modelConfig.type !== 'clip_vision' && (
+            <FormControl flexDir="column" alignItems="flex-start" gap={1}>
+              <FormLabel>{t('modelManager.baseModel')}</FormLabel>
+              <BaseModelSelect control={form.control} />
+            </FormControl>
+          )}
           {modelConfig.type === 'main' && (
             <FormControl flexDir="column" alignItems="flex-start" gap={1}>
               <FormLabel>{t('modelManager.variant')}</FormLabel>
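Both the settings heading and the base-model select are now skipped when editing a clip_vision config; the read-only view panel gets the equivalent treatment below via a withSettings flag.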
@@ -4,7 +4,7 @@ import { ModelConvertButton } from 'features/modelManagerV2/subpanels/ModelPanel
 import { ModelEditButton } from 'features/modelManagerV2/subpanels/ModelPanel/ModelEditButton';
 import { ModelHeader } from 'features/modelManagerV2/subpanels/ModelPanel/ModelHeader';
 import { TriggerPhrases } from 'features/modelManagerV2/subpanels/ModelPanel/TriggerPhrases';
-import { memo } from 'react';
+import { memo, useMemo } from 'react';
 import { useTranslation } from 'react-i18next';
 import type { AnyModelConfig } from 'services/api/types';

@@ -17,6 +17,20 @@ type Props = {

 export const ModelView = memo(({ modelConfig }: Props) => {
   const { t } = useTranslation();
+  const withSettings = useMemo(() => {
+    if (modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner') {
+      return true;
+    }
+    if (modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') {
+      return true;
+    }
+    if (modelConfig.type === 'main' || modelConfig.type === 'lora') {
+      return true;
+    }
+
+    return false;
+  }, [modelConfig.base, modelConfig.type]);
+
   return (
     <Flex flexDir="column" gap={4}>
       <ModelHeader modelConfig={modelConfig}>
@@ -50,15 +64,19 @@ export const ModelView = memo(({ modelConfig }: Props) => {
         )}
       </SimpleGrid>
     </Box>
-    <Box layerStyle="second" borderRadius="base" p={4}>
-      {modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
-        <MainModelDefaultSettings modelConfig={modelConfig} />
-      )}
-      {(modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') && (
-        <ControlNetOrT2IAdapterDefaultSettings modelConfig={modelConfig} />
-      )}
-      {(modelConfig.type === 'main' || modelConfig.type === 'lora') && <TriggerPhrases modelConfig={modelConfig} />}
-    </Box>
+    {withSettings && (
+      <Box layerStyle="second" borderRadius="base" p={4}>
+        {modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner' && (
+          <MainModelDefaultSettings modelConfig={modelConfig} />
+        )}
+        {(modelConfig.type === 'controlnet' || modelConfig.type === 't2i_adapter') && (
+          <ControlNetOrT2IAdapterDefaultSettings modelConfig={modelConfig} />
+        )}
+        {(modelConfig.type === 'main' || modelConfig.type === 'lora') && (
+          <TriggerPhrases modelConfig={modelConfig} />
+        )}
+      </Box>
+    )}
   </Flex>
 </Flex>
 );
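withSettings reproduces exactly the three conditions rendered inside the settings Box, so the Box (and its visible panel styling) is omitted entirely for model types that would otherwise show an empty container, clip_vision among them.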
@@ -10,6 +10,7 @@ import {
 import type { AnyModelConfig } from 'services/api/types';
 import {
   isCLIPEmbedModelConfig,
+  isCLIPVisionModelConfig,
   isControlNetModelConfig,
   isControlNetOrT2IAdapterModelConfig,
   isFluxMainModelModelConfig,
@@ -58,6 +59,7 @@ export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
 export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
 export const useVAEModels = buildModelsHook(isVAEModelConfig);
 export const useFluxVAEModels = buildModelsHook(isFluxVAEModelConfig);
+export const useCLIPVisionModels = buildModelsHook(isCLIPVisionModelConfig);

 // const buildModelsSelector =
 //   <T extends AnyModelConfig>(typeGuard: (config: AnyModelConfig) => config is T): Selector<RootState, T[]> =>
@@ -99,6 +99,10 @@ export const isIPAdapterModelConfig = (config: AnyModelConfig): config is IPAdap
   return config.type === 'ip_adapter';
 };

+export const isCLIPVisionModelConfig = (config: AnyModelConfig): config is CLIPVisionDiffusersConfig => {
+  return config.type === 'clip_vision';
+};
+
 export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAdapterModelConfig => {
   return config.type === 't2i_adapter';
 };
@@ -1 +1 @@
-__version__ = "5.0.1"
+__version__ = "5.0.2"
@@ -5,6 +5,7 @@ from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils impo
     is_state_dict_likely_in_flux_diffusers_format,
     lora_model_from_flux_diffusers_state_dict,
 )
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
     state_dict_keys as flux_diffusers_state_dict_keys,
 )
@@ -50,6 +51,7 @@ def test_lora_model_from_flux_diffusers_state_dict():
     concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
     expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
     assert len(model.layers) == len(expected_lora_layers)
+    assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())


 def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
@@ -5,12 +5,11 @@ import torch
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.util import params
 from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
-    FLUX_KOHYA_CLIP_PREFIX,
-    FLUX_KOHYA_TRANFORMER_PREFIX,
     _convert_flux_transformer_kohya_state_dict_to_invoke_format,
     is_state_dict_likely_in_flux_kohya_format,
     lora_model_from_flux_kohya_state_dict,
 )
+from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
     state_dict_keys as flux_diffusers_state_dict_keys,
 )
@@ -95,8 +94,8 @@ def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
     expected_layer_keys: set[str] = set()
     for k in sd_keys:
         # Replace prefixes.
-        k = k.replace("lora_unet_", FLUX_KOHYA_TRANFORMER_PREFIX)
-        k = k.replace("lora_te1_", FLUX_KOHYA_CLIP_PREFIX)
+        k = k.replace("lora_unet_", FLUX_LORA_TRANSFORMER_PREFIX)
+        k = k.replace("lora_te1_", FLUX_LORA_CLIP_PREFIX)
         # Remove suffixes.
         k = k.replace(".lora_up.weight", "")
         k = k.replace(".lora_down.weight", "")
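The test derives its expected layer keys directly from the raw Kohya keys. Tracing one hypothetical key through the replacements above (real keys come from the fixture state dicts):

```python
FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"

k = "lora_unet_double_blocks_0_img_attn_proj.lora_up.weight"  # hypothetical Kohya key
k = k.replace("lora_unet_", FLUX_LORA_TRANSFORMER_PREFIX)
k = k.replace(".lora_up.weight", "")
assert k == "lora_transformer-double_blocks_0_img_attn_proj"
```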