Compare commits

..

4 Commits

Author SHA1 Message Date
psychedelicious
e04c25eba7 Merge branch 'main' into fix/diffusers-embeddings 2023-08-01 22:14:43 +10:00
Lincoln Stein
be61ffdbf6 Merge branch 'main' into fix/diffusers-embeddings 2023-08-01 00:21:50 -04:00
Lincoln Stein
823b879329 Merge branch 'main' into fix/diffusers-embeddings 2023-07-31 21:04:08 -04:00
Lincoln Stein
17c901aaf7 fix diffusers-style textual embeddings
- also fix a couple places where the wrong base was used for relative model paths
2023-07-31 21:00:12 -04:00
8 changed files with 178 additions and 297 deletions

View File

@@ -34,10 +34,6 @@
cudaPackages.cudnn cudaPackages.cudnn
cudaPackages.cuda_nvrtc cudaPackages.cuda_nvrtc
cudatoolkit cudatoolkit
pkgconfig
libconfig
cmake
blas
freeglut freeglut
glib glib
gperf gperf
@@ -46,12 +42,6 @@
libGLU libGLU
linuxPackages.nvidia_x11 linuxPackages.nvidia_x11
python python
(opencv4.override {
enableGtk3 = true;
enableFfmpeg = true;
enableCuda = true;
enableUnfree = true;
})
stdenv.cc stdenv.cc
stdenv.cc.cc.lib stdenv.cc.cc.lib
xorg.libX11 xorg.libX11

View File

@@ -108,15 +108,14 @@ class CompelInvocation(BaseInvocation):
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1] name = trigger[1:-1]
try: try:
ti_list.append(( ti_list.append(
name,
context.services.model_manager.get_model( context.services.model_manager.get_model(
model_name=name, model_name=name,
base_model=self.clip.text_encoder.base_model, base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
context=context, context=context,
).context.model ).context.model
)) )
except ModelNotFoundException: except ModelNotFoundException:
# print(e) # print(e)
# import traceback # import traceback
@@ -197,15 +196,14 @@ class SDXLPromptInvocationBase:
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
name = trigger[1:-1] name = trigger[1:-1]
try: try:
ti_list.append(( ti_list.append(
name,
context.services.model_manager.get_model( context.services.model_manager.get_model(
model_name=name, model_name=name,
base_model=clip_field.text_encoder.base_model, base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
context=context, context=context,
).context.model ).context.model
)) )
except ModelNotFoundException: except ModelNotFoundException:
# print(e) # print(e)
# import traceback # import traceback
@@ -272,15 +270,14 @@ class SDXLPromptInvocationBase:
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
name = trigger[1:-1] name = trigger[1:-1]
try: try:
ti_list.append(( ti_list.append(
name,
context.services.model_manager.get_model( context.services.model_manager.get_model(
model_name=name, model_name=name,
base_model=clip_field.text_encoder.base_model, base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
context=context, context=context,
).context.model ).context.model
)) )
except ModelNotFoundException: except ModelNotFoundException:
# print(e) # print(e)
# import traceback # import traceback

View File

@@ -65,6 +65,7 @@ class ONNXPromptInvocation(BaseInvocation):
**self.clip.text_encoder.dict(), **self.clip.text_encoder.dict(),
) )
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack: with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
loras = [ loras = [
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
for lora in self.clip.loras for lora in self.clip.loras
@@ -74,14 +75,20 @@ class ONNXPromptInvocation(BaseInvocation):
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1] name = trigger[1:-1]
try: try:
ti_list.append(( ti_list.append(
name, # stack.enter_context(
# context.services.model_manager.get_model(
# model_name=name,
# base_model=self.clip.text_encoder.base_model,
# model_type=ModelType.TextualInversion,
# )
# )
context.services.model_manager.get_model( context.services.model_manager.get_model(
model_name=name, model_name=name,
base_model=self.clip.text_encoder.base_model, base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
).context.model ).context.model
)) )
except Exception: except Exception:
# print(e) # print(e)
# import traceback # import traceback

View File

@@ -562,7 +562,7 @@ class ModelPatcher:
cls, cls,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
ti_list: List[Tuple[str, Any]], ti_list: List[Any],
) -> Tuple[CLIPTokenizer, TextualInversionManager]: ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
init_tokens_count = None init_tokens_count = None
new_tokens_added = None new_tokens_added = None
@@ -572,27 +572,27 @@ class ModelPatcher:
ti_manager = TextualInversionManager(ti_tokenizer) ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
def _get_trigger(ti_name, index): def _get_trigger(ti, index):
trigger = ti_name trigger = ti.name
if index > 0: if index > 0:
trigger += f"-!pad-{i}" trigger += f"-!pad-{i}"
return f"<{trigger}>" return f"<{trigger}>"
# modify tokenizer # modify tokenizer
new_tokens_added = 0 new_tokens_added = 0
for ti_name, ti in ti_list: for ti in ti_list:
for i in range(ti.embedding.shape[0]): for i in range(ti.embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
# modify text_encoder # modify text_encoder
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added) text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
model_embeddings = text_encoder.get_input_embeddings() model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list: for ti in ti_list:
ti_tokens = [] ti_tokens = []
for i in range(ti.embedding.shape[0]): for i in range(ti.embedding.shape[0]):
embedding = ti.embedding[i] embedding = ti.embedding[i]
trigger = _get_trigger(ti_name, i) trigger = _get_trigger(ti, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger) token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id: if token_id == ti_tokenizer.unk_token_id:
@@ -637,6 +637,7 @@ class ModelPatcher:
class TextualInversionModel: class TextualInversionModel:
name: str
embedding: torch.Tensor # [n, 768]|[n, 1280] embedding: torch.Tensor # [n, 768]|[n, 1280]
@classmethod @classmethod
@@ -650,6 +651,10 @@ class TextualInversionModel:
file_path = Path(file_path) file_path = Path(file_path)
result = cls() # TODO: result = cls() # TODO:
if file_path.name == "learned_embeds.bin":
result.name = file_path.parent.name
else:
result.name = file_path.stem
if file_path.suffix == ".safetensors": if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu") state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
@@ -826,7 +831,7 @@ class ONNXModelPatcher:
cls, cls,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
text_encoder: IAIOnnxRuntimeModel, text_encoder: IAIOnnxRuntimeModel,
ti_list: List[Tuple[str, Any]], ti_list: List[Any],
) -> Tuple[CLIPTokenizer, TextualInversionManager]: ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
from .models.base import IAIOnnxRuntimeModel from .models.base import IAIOnnxRuntimeModel
@@ -839,17 +844,17 @@ class ONNXModelPatcher:
ti_tokenizer = copy.deepcopy(tokenizer) ti_tokenizer = copy.deepcopy(tokenizer)
ti_manager = TextualInversionManager(ti_tokenizer) ti_manager = TextualInversionManager(ti_tokenizer)
def _get_trigger(ti_name, index): def _get_trigger(ti, index):
trigger = ti_name trigger = ti.name
if index > 0: if index > 0:
trigger += f"-!pad-{i}" trigger += f"-!pad-{i}"
return f"<{trigger}>" return f"<{trigger}>"
# modify tokenizer # modify tokenizer
new_tokens_added = 0 new_tokens_added = 0
for ti_name, ti in ti_list: for ti in ti_list:
for i in range(ti.embedding.shape[0]): for i in range(ti.embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i)) new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
# modify text_encoder # modify text_encoder
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"] orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
@@ -859,11 +864,11 @@ class ONNXModelPatcher:
axis=0, axis=0,
) )
for ti_name, ti in ti_list: for ti in ti_list:
ti_tokens = [] ti_tokens = []
for i in range(ti.embedding.shape[0]): for i in range(ti.embedding.shape[0]):
embedding = ti.embedding[i].detach().numpy() embedding = ti.embedding[i].detach().numpy()
trigger = _get_trigger(ti_name, i) trigger = _get_trigger(ti, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger) token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id: if token_id == ti_tokenizer.unk_token_id:

View File

@@ -188,7 +188,7 @@ class ModelCache(object):
cache_entry = self._cached_models.get(key, None) cache_entry = self._cached_models.get(key, None)
if cache_entry is None: if cache_entry is None:
self.logger.info( self.logger.info(
f"Loading model {model_path}, type {base_model.value}:{model_type.value}:{submodel.value if submodel else ''}" f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
) )
# this will remove older cached models until # this will remove older cached models until
@@ -210,31 +210,6 @@ class ModelCache(object):
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size) return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
def clear_one_model(self) -> bool:
reserved = self.max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated()
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
smallest_key = None
smallest_size = float("inf")
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if not cache_entry.locked and cache_entry.loaded:
if cache_entry.size > 0 and cache_entry.size < smallest_size:
smallest_key = model_key
smallest_size = cache_entry.size
if smallest_key is not None:
cache_entry = self._cached_models[smallest_key]
self.logger.debug(f"!!!!!!!!!!!Offloading {smallest_key} from {self.execution_device} into {self.storage_device}")
with VRAMUsage() as mem:
cache_entry.model.to(self.storage_device)
self.logger.debug(f"GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB")
vram_in_use += mem.vram_used # note vram_used is negative
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
torch.cuda.empty_cache()
gc.collect()
return smallest_key is not None
class ModelLocker(object): class ModelLocker(object):
def __init__(self, cache, key, model, gpu_load, size_needed): def __init__(self, cache, key, model, gpu_load, size_needed):
""" """
@@ -261,48 +236,17 @@ class ModelCache(object):
self.cache_entry.lock() self.cache_entry.lock()
try: try:
self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}") if self.cache.lazy_offloading:
while True: self.cache._offload_unlocked_models(self.size_needed)
try:
with VRAMUsage() as mem:
self.model.to(self.cache.execution_device) # move into GPU
self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}") if self.model.device != self.cache.execution_device:
self.cache._print_cuda_stats() self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
with VRAMUsage() as mem:
self.model.to(self.cache.execution_device) # move into GPU
self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
def my_forward(module, cache, *args, **kwargs): self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
while True: self.cache._print_cuda_stats()
try:
return module._orig_forward(*args, **kwargs)
except:
if not cache.clear_one_model():
raise
import functools
from diffusers.models.unet_2d_blocks import DownBlock2D, CrossAttnDownBlock2D, UpBlock2D, CrossAttnUpBlock2D
from transformers.models.clip.modeling_clip import CLIPEncoderLayer
from diffusers.models.unet_2d_blocks import DownEncoderBlock2D, UpDecoderBlock2D
for module_name, module in self.model.named_modules():
if type(module) not in [
DownBlock2D, CrossAttnDownBlock2D, UpBlock2D, CrossAttnUpBlock2D, # unet blocks
CLIPEncoderLayer, # CLIPTextTransformer clip
DownEncoderBlock2D, UpDecoderBlock2D, # vae
]:
continue
# better here filter to only specific model modules
module._orig_forward = module.forward
module.forward = functools.partial(my_forward, module, self.cache)
self.model._orig_forward = self.model.forward
self.model.forward = functools.partial(my_forward, self.model, self.cache)
break
except:
if not self.cache.clear_one_model():
raise
except: except:
self.cache_entry.unlock() self.cache_entry.unlock()
@@ -320,19 +264,10 @@ class ModelCache(object):
if not hasattr(self.model, "to"): if not hasattr(self.model, "to"):
return return
if hasattr(self.model, "_orig_forward"):
self.model.forward = self.model._orig_forward
delattr(self.model, "_orig_forward")
for module_name, module in self.model.named_modules():
if hasattr(module, "_orig_forward"):
module.forward = module._orig_forward
delattr(module, "_orig_forward")
self.cache_entry.unlock() self.cache_entry.unlock()
#if not self.cache.lazy_offloading: if not self.cache.lazy_offloading:
# self.cache._offload_unlocked_models() self.cache._offload_unlocked_models()
# self.cache._print_cuda_stats() self.cache._print_cuda_stats()
# TODO: should it be called untrack_model? # TODO: should it be called untrack_model?
def uncache_model(self, cache_id: str): def uncache_model(self, cache_id: str):

View File

@@ -472,7 +472,7 @@ class ModelManager(object):
if submodel_type is not None and hasattr(model_config, submodel_type): if submodel_type is not None and hasattr(model_config, submodel_type):
override_path = getattr(model_config, submodel_type) override_path = getattr(model_config, submodel_type)
if override_path: if override_path:
model_path = self.app_config.root_path / override_path model_path = self.resolve_path(override_path)
model_type = submodel_type model_type = submodel_type
submodel_type = None submodel_type = None
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
@@ -670,7 +670,7 @@ class ModelManager(object):
# TODO: if path changed and old_model.path inside models folder should we delete this too? # TODO: if path changed and old_model.path inside models folder should we delete this too?
# remove conversion cache as config changed # remove conversion cache as config changed
old_model_path = self.app_config.root_path / old_model.path old_model_path = self.resolve_model_path(old_model.path)
old_model_cache = self._get_model_cache_path(old_model_path) old_model_cache = self._get_model_cache_path(old_model_path)
if old_model_cache.exists(): if old_model_cache.exists():
if old_model_cache.is_dir(): if old_model_cache.is_dir():
@@ -780,7 +780,7 @@ class ModelManager(object):
model_type, model_type,
**submodel, **submodel,
) )
checkpoint_path = self.app_config.root_path / info["path"] checkpoint_path = self.resolve_model_path(info["path"])
old_diffusers_path = self.resolve_model_path(model.location) old_diffusers_path = self.resolve_model_path(model.location)
new_diffusers_path = ( new_diffusers_path = (
dest_directory or self.app_config.models_path / base_model.value / model_type.value dest_directory or self.app_config.models_path / base_model.value / model_type.value
@@ -992,7 +992,7 @@ class ModelManager(object):
model_manager=self, model_manager=self,
prediction_type_helper=ask_user_for_prediction_type, prediction_type_helper=ask_user_for_prediction_type,
) )
known_paths = {config.root_path / x["path"] for x in self.list_models()} known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()}
directories = { directories = {
config.root_path / x config.root_path / x
for x in [ for x in [

View File

@@ -1,4 +1,4 @@
import { ButtonGroup, Flex, Spinner, Text } from '@chakra-ui/react'; import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
import { EntityState } from '@reduxjs/toolkit'; import { EntityState } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
@@ -6,23 +6,23 @@ import { forEach } from 'lodash-es';
import type { ChangeEvent, PropsWithChildren } from 'react'; import type { ChangeEvent, PropsWithChildren } from 'react';
import { useCallback, useState } from 'react'; import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import { import {
LoRAModelConfigEntity,
MainModelConfigEntity, MainModelConfigEntity,
OnnxModelConfigEntity, OnnxModelConfigEntity,
useGetLoRAModelsQuery,
useGetMainModelsQuery, useGetMainModelsQuery,
useGetOnnxModelsQuery, useGetOnnxModelsQuery,
useGetLoRAModelsQuery,
LoRAModelConfigEntity,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem'; import ModelListItem from './ModelListItem';
import { ALL_BASE_MODELS } from 'services/api/constants';
type ModelListProps = { type ModelListProps = {
selectedModelId: string | undefined; selectedModelId: string | undefined;
setSelectedModelId: (name: string | undefined) => void; setSelectedModelId: (name: string | undefined) => void;
}; };
type ModelFormat = 'all' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx'; type ModelFormat = 'images' | 'checkpoint' | 'diffusers' | 'olive' | 'onnx';
type ModelType = 'main' | 'lora' | 'onnx'; type ModelType = 'main' | 'lora' | 'onnx';
@@ -33,63 +33,47 @@ const ModelList = (props: ModelListProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>(''); const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] = const [modelFormatFilter, setModelFormatFilter] =
useState<CombinedModelFormat>('all'); useState<CombinedModelFormat>('images');
const { filteredDiffusersModels, isLoadingDiffusersModels } = const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
useGetMainModelsQuery(ALL_BASE_MODELS, { selectFromResult: ({ data }) => ({
selectFromResult: ({ data, isLoading }) => ({ filteredDiffusersModels: modelsFilter(
filteredDiffusersModels: modelsFilter( data,
data, 'main',
'main', 'diffusers',
'diffusers', nameFilter
nameFilter ),
), }),
isLoadingDiffusersModels: isLoading, });
}),
});
const { filteredCheckpointModels, isLoadingCheckpointModels } = const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
useGetMainModelsQuery(ALL_BASE_MODELS, { selectFromResult: ({ data }) => ({
selectFromResult: ({ data, isLoading }) => ({ filteredCheckpointModels: modelsFilter(
filteredCheckpointModels: modelsFilter( data,
data, 'main',
'main', 'checkpoint',
'checkpoint', nameFilter
nameFilter ),
), }),
isLoadingCheckpointModels: isLoading, });
}),
});
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery( const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
undefined, selectFromResult: ({ data }) => ({
{ filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
selectFromResult: ({ data, isLoading }) => ({ }),
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter), });
isLoadingLoraModels: isLoading,
}),
}
);
const { filteredOnnxModels, isLoadingOnnxModels } = useGetOnnxModelsQuery( const { filteredOnnxModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
ALL_BASE_MODELS, selectFromResult: ({ data }) => ({
{ filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter),
selectFromResult: ({ data, isLoading }) => ({ }),
filteredOnnxModels: modelsFilter(data, 'onnx', 'onnx', nameFilter), });
isLoadingOnnxModels: isLoading,
}),
}
);
const { filteredOliveModels, isLoadingOliveModels } = useGetOnnxModelsQuery( const { filteredOliveModels } = useGetOnnxModelsQuery(ALL_BASE_MODELS, {
ALL_BASE_MODELS, selectFromResult: ({ data }) => ({
{ filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter),
selectFromResult: ({ data, isLoading }) => ({ }),
filteredOliveModels: modelsFilter(data, 'onnx', 'olive', nameFilter), });
isLoadingOliveModels: isLoading,
}),
}
);
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => { const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value); setNameFilter(e.target.value);
@@ -100,8 +84,8 @@ const ModelList = (props: ModelListProps) => {
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}> <Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
<ButtonGroup isAttached> <ButtonGroup isAttached>
<IAIButton <IAIButton
onClick={() => setModelFormatFilter('all')} onClick={() => setModelFormatFilter('images')}
isChecked={modelFormatFilter === 'all'} isChecked={modelFormatFilter === 'images'}
size="sm" size="sm"
> >
{t('modelManager.allModels')} {t('modelManager.allModels')}
@@ -155,76 +139,95 @@ const ModelList = (props: ModelListProps) => {
maxHeight={window.innerHeight - 280} maxHeight={window.innerHeight - 280}
overflow="scroll" overflow="scroll"
> >
{/* Diffusers List */} {['images', 'diffusers'].includes(modelFormatFilter) &&
{isLoadingDiffusersModels && (
<FetchingModelsLoader loadingMessage="Loading Diffusers..." />
)}
{['all', 'diffusers'].includes(modelFormatFilter) &&
!isLoadingDiffusersModels &&
filteredDiffusersModels.length > 0 && ( filteredDiffusersModels.length > 0 && (
<ModelListWrapper <StyledModelContainer>
title="Diffusers" <Flex sx={{ gap: 2, flexDir: 'column' }}>
modelList={filteredDiffusersModels} <Text variant="subtext" fontSize="sm">
selected={{ selectedModelId, setSelectedModelId }} Diffusers
key="diffusers" </Text>
/> {filteredDiffusersModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)} )}
{/* Checkpoints List */} {['images', 'checkpoint'].includes(modelFormatFilter) &&
{isLoadingCheckpointModels && (
<FetchingModelsLoader loadingMessage="Loading Checkpoints..." />
)}
{['all', 'checkpoint'].includes(modelFormatFilter) &&
!isLoadingCheckpointModels &&
filteredCheckpointModels.length > 0 && ( filteredCheckpointModels.length > 0 && (
<ModelListWrapper <StyledModelContainer>
title="Checkpoints" <Flex sx={{ gap: 2, flexDir: 'column' }}>
modelList={filteredCheckpointModels} <Text variant="subtext" fontSize="sm">
selected={{ selectedModelId, setSelectedModelId }} Checkpoints
key="checkpoints" </Text>
/> {filteredCheckpointModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)} )}
{['images', 'olive'].includes(modelFormatFilter) &&
{/* LoRAs List */}
{isLoadingLoraModels && (
<FetchingModelsLoader loadingMessage="Loading LoRAs..." />
)}
{['all', 'lora'].includes(modelFormatFilter) &&
!isLoadingLoraModels &&
filteredLoraModels.length > 0 && (
<ModelListWrapper
title="LoRAs"
modelList={filteredLoraModels}
selected={{ selectedModelId, setSelectedModelId }}
key="loras"
/>
)}
{/* Olive List */}
{isLoadingOliveModels && (
<FetchingModelsLoader loadingMessage="Loading Olives..." />
)}
{['all', 'olive'].includes(modelFormatFilter) &&
!isLoadingOliveModels &&
filteredOliveModels.length > 0 && ( filteredOliveModels.length > 0 && (
<ModelListWrapper <StyledModelContainer>
title="Olives" <Flex sx={{ gap: 2, flexDir: 'column' }}>
modelList={filteredOliveModels} <Text variant="subtext" fontSize="sm">
selected={{ selectedModelId, setSelectedModelId }} Olives
key="olive" </Text>
/> {filteredOliveModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)} )}
{/* Onnx List */} {['images', 'onnx'].includes(modelFormatFilter) &&
{isLoadingOnnxModels && (
<FetchingModelsLoader loadingMessage="Loading ONNX..." />
)}
{['all', 'onnx'].includes(modelFormatFilter) &&
!isLoadingOnnxModels &&
filteredOnnxModels.length > 0 && ( filteredOnnxModels.length > 0 && (
<ModelListWrapper <StyledModelContainer>
title="ONNX" <Flex sx={{ gap: 2, flexDir: 'column' }}>
modelList={filteredOnnxModels} <Text variant="subtext" fontSize="sm">
selected={{ selectedModelId, setSelectedModelId }} Onnx
key="onnx" </Text>
/> {filteredOnnxModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)}
{['images', 'lora'].includes(modelFormatFilter) &&
filteredLoraModels.length > 0 && (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
LoRAs
</Text>
{filteredLoraModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
)} )}
</Flex> </Flex>
</Flex> </Flex>
@@ -284,52 +287,3 @@ const StyledModelContainer = (props: PropsWithChildren) => {
</Flex> </Flex>
); );
}; };
type ModelListWrapperProps = {
title: string;
modelList:
| MainModelConfigEntity[]
| LoRAModelConfigEntity[]
| OnnxModelConfigEntity[];
selected: ModelListProps;
};
function ModelListWrapper(props: ModelListWrapperProps) {
const { title, modelList, selected } = props;
return (
<StyledModelContainer>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
{title}
</Text>
{modelList.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selected.selectedModelId === model.id}
setSelectedModelId={selected.setSelectedModelId}
/>
))}
</Flex>
</StyledModelContainer>
);
}
function FetchingModelsLoader({ loadingMessage }: { loadingMessage?: string }) {
return (
<StyledModelContainer>
<Flex
justifyContent="center"
alignItems="center"
flexDirection="column"
p={4}
gap={8}
>
<Spinner />
<Text variant="subtext">
{loadingMessage ? loadingMessage : 'Fetching...'}
</Text>
</Flex>
</StyledModelContainer>
);
}

View File

@@ -181,7 +181,7 @@ export const modelsApi = api.injectEndpoints({
}, },
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ type: 'OnnxModel', id: LIST_TAG }, { id: 'OnnxModel', type: LIST_TAG },
]; ];
if (result) { if (result) {
@@ -266,7 +266,6 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
importMainModels: build.mutation< importMainModels: build.mutation<
@@ -283,7 +282,6 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({ addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
@@ -297,7 +295,6 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
deleteMainModels: build.mutation< deleteMainModels: build.mutation<
@@ -313,7 +310,6 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
convertMainModels: build.mutation< convertMainModels: build.mutation<
@@ -330,7 +326,6 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({ mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
@@ -344,7 +339,6 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
syncModels: build.mutation<SyncModelsResponse, void>({ syncModels: build.mutation<SyncModelsResponse, void>({
@@ -357,7 +351,6 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: [ invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG }, { type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'OnnxModel', id: LIST_TAG },
], ],
}), }),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({ getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({