feat(ui): add main model trigger phrases

This commit is contained in:
psychedelicious
2024-03-07 15:24:28 +11:00
parent 2f0a653a7f
commit fdecb886b2
3 changed files with 30 additions and 7 deletions

View File

@@ -64,7 +64,7 @@ export const ModelView = () => {
<DefaultSettings />
</Box>
)}
{data.type === 'lora' && (
{(data.type === 'main' || data.type === 'lora') && (
<Box layerStyle="second" borderRadius="base" p={4}>
<TriggerPhrases />
</Box>

View File

@@ -1,9 +1,11 @@
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import { selectLoraSlice } from 'features/lora/store/loraSlice';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import type { PromptTriggerSelectProps } from 'features/prompt/types';
import { t } from 'i18next';
import { flatten, map } from 'lodash-es';
@@ -13,18 +15,23 @@ import {
loraModelsAdapterSelectors,
textualInversionModelsAdapterSelectors,
useGetLoRAModelsQuery,
useGetModelConfigQuery,
useGetTextualInversionModelsQuery,
} from 'services/api/endpoints/models';
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
const selectLoRAs = createMemoizedSelector(selectLoraSlice, (loras) => loras.loras);
const selectMainModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => {
const { t } = useTranslation();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const mainModel = useAppSelector(selectMainModel);
const addedLoRAs = useAppSelector(selectLoRAs);
const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery(
mainModel?.key ?? skipToken
);
const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery();
const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery();
@@ -46,7 +53,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
if (tiModels) {
const embeddingOptions = textualInversionModelsAdapterSelectors
.selectAll(tiModels)
.filter((ti) => ti.base === currentBaseModel)
.filter((ti) => ti.base === mainModelConfig?.base)
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
if (embeddingOptions.length > 0) {
@@ -71,18 +78,33 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
if (triggerPhraseOptions.length > 0) {
_options.push({
label: t('modelManager.triggerPhrases'),
label: t('modelManager.loraTriggerPhrases'),
options: flatten(triggerPhraseOptions),
});
}
}
if (mainModelConfig && mainModelConfig.trigger_phrases?.length) {
_options.push({
label: t('modelManager.mainModelTriggerPhrases'),
options: mainModelConfig.trigger_phrases.map((triggerPhrase) => ({
label: triggerPhrase,
value: triggerPhrase,
})),
});
}
return _options;
}, [tiModels, loraModels, t, currentBaseModel, addedLoRAs]);
}, [tiModels, loraModels, mainModelConfig, t, addedLoRAs]);
return (
<FormControl>
<Combobox
placeholder={isLoadingLoRAs || isLoadingTIs ? t('common.loading') : t('prompt.addPromptTrigger')}
placeholder={
isLoadingLoRAs || isLoadingTIs || isLoadingMainModelConfig
? t('common.loading')
: t('prompt.addPromptTrigger')
}
defaultMenuIsOpen
autoFocus
value={null}