From fdecb886b2dc2a436fedc9ee0e4f4958134f5a30 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Thu, 7 Mar 2024 15:24:28 +1100
Subject: [PATCH] feat(ui): add main model trigger phrases

---
invokeai/frontend/web/public/locales/en.json | 3 +-
.../subpanels/ModelPanel/ModelView.tsx | 2 +-
.../features/prompt/PromptTriggerSelect.tsx | 32 ++++++++++++++++---
3 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 223a26d04d..fe31306e12 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -855,7 +855,8 @@
"statusConverting": "Converting",
"syncModels": "Sync Models",
"syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
- "triggerPhrases": "Trigger Phrases",
+ "loraTriggerPhrases": "LoRA Trigger Phrases",
+ "mainModelTriggerPhrases": "Main Model Trigger Phrases",
"typePhraseHere": "Type phrase here",
"upcastAttention": "Upcast Attention",
"uploadImage": "Upload Image",
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx
index abcc46b3aa..a7d2e61f7b 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx
+++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx
@@ -64,7 +64,7 @@ export const ModelView = () => {
)}
- {data.type === 'lora' && (
+ {(data.type === 'main' || data.type === 'lora') && (
diff --git a/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx b/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx
index 2b7f17293c..06055af9bd 100644
--- a/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx
+++ b/invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx
@@ -1,9 +1,11 @@
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
+import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import { selectLoraSlice } from 'features/lora/store/loraSlice';
+import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import type { PromptTriggerSelectProps } from 'features/prompt/types';
import { t } from 'i18next';
import { flatten, map } from 'lodash-es';
@@ -13,18 +15,23 @@ import {
loraModelsAdapterSelectors,
textualInversionModelsAdapterSelectors,
useGetLoRAModelsQuery,
+ useGetModelConfigQuery,
useGetTextualInversionModelsQuery,
} from 'services/api/endpoints/models';
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
const selectLoRAs = createMemoizedSelector(selectLoraSlice, (loras) => loras.loras);
+const selectMainModel = createMemoizedSelector(selectGenerationSlice, (generation) => generation.model);
export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => {
const { t } = useTranslation();
- const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
+ const mainModel = useAppSelector(selectMainModel);
const addedLoRAs = useAppSelector(selectLoRAs);
+ const { data: mainModelConfig, isLoading: isLoadingMainModelConfig } = useGetModelConfigQuery(
+ mainModel?.key ?? skipToken
+ );
const { data: loraModels, isLoading: isLoadingLoRAs } = useGetLoRAModelsQuery();
const { data: tiModels, isLoading: isLoadingTIs } = useGetTextualInversionModelsQuery();
@@ -46,7 +53,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
if (tiModels) {
const embeddingOptions = textualInversionModelsAdapterSelectors
.selectAll(tiModels)
- .filter((ti) => ti.base === currentBaseModel)
+ .filter((ti) => ti.base === mainModelConfig?.base)
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
if (embeddingOptions.length > 0) {
@@ -71,18 +78,33 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
if (triggerPhraseOptions.length > 0) {
_options.push({
- label: t('modelManager.triggerPhrases'),
+ label: t('modelManager.loraTriggerPhrases'),
options: flatten(triggerPhraseOptions),
});
}
}
+
+ if (mainModelConfig && mainModelConfig.trigger_phrases?.length) {
+ _options.push({
+ label: t('modelManager.mainModelTriggerPhrases'),
+ options: mainModelConfig.trigger_phrases.map((triggerPhrase) => ({
+ label: triggerPhrase,
+ value: triggerPhrase,
+ })),
+ });
+ }
+
return _options;
- }, [tiModels, loraModels, t, currentBaseModel, addedLoRAs]);
+ }, [tiModels, loraModels, mainModelConfig, t, addedLoRAs]);
return (