Added related model support

This commit is contained in:
Kent Keirsey
2025-07-02 14:58:20 -04:00
committed by psychedelicious
parent c10865c7ef
commit 49ae66d94a

View File

@@ -1,5 +1,5 @@
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { Combobox, FormControl, Icon } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
@@ -10,12 +10,16 @@ import type { PromptTriggerSelectProps } from 'features/prompt/types';
import { t } from 'i18next';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLinkSimple } from 'react-icons/pi';
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import { useEmbeddingModels, useLoRAModels } from 'services/api/hooks/modelsByType';
import { isNonRefinerMainModelConfig } from 'services/api/types';
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
type RelatedEmbedding = ComboboxOption & { starred?: boolean };
export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => {
const { t } = useTranslation();
@@ -27,6 +31,27 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
const [loraModels, { isLoading: isLoadingLoRAs }] = useLoRAModels();
const [tiModels, { isLoading: isLoadingTIs }] = useEmbeddingModels();
// Get related model keys for current selected models
const selectedModelKeys = useMemo(() => {
const keys: string[] = [];
if (mainModel) {
keys.push(mainModel.key);
}
for (const { model } of addedLoRAs) {
keys.push(model.key);
}
return keys;
}, [mainModel, addedLoRAs]);
const { relatedModelKeys } = useGetRelatedModelIdsBatchQuery(selectedModelKeys, {
selectFromResult: ({ data }) => {
if (!data) {
return { relatedModelKeys: [] };
}
return { relatedModelKeys: data };
},
});
const _onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!v) {
@@ -62,9 +87,25 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
}
if (tiModels) {
const embeddingOptions = tiModels
// Create embedding options with starred property for related models
const embeddingOptions: RelatedEmbedding[] = tiModels
.filter((ti) => ti.base === mainModelConfig?.base)
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
.map((model) => ({
label: model.name,
value: `<${model.name}>`,
starred: relatedModelKeys.includes(model.key),
}));
// Sort so related embeddings come first
embeddingOptions.sort((a, b) => {
if (a.starred && !b.starred) {
return -1;
}
if (!a.starred && b.starred) {
return 1;
}
return 0;
});
if (embeddingOptions.length > 0) {
_options.push({
@@ -85,7 +126,20 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
}
return _options;
}, [tiModels, loraModels, mainModelConfig, t, addedLoRAs]);
}, [tiModels, loraModels, mainModelConfig, t, addedLoRAs, relatedModelKeys]);
const formatOptionLabel = useCallback((option: ComboboxOption) => {
const embeddingOption = option as RelatedEmbedding;
if (embeddingOption.starred) {
return (
<div style={{ display: 'flex', alignItems: 'center', gap: '8px' }}>
<Icon as={PiLinkSimple} color="invokeYellow.500" boxSize={3} />
{option.label}
</div>
);
}
return option.label;
}, []);
return (
<FormControl>
@@ -104,6 +158,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
onMenuClose={onClose}
data-testid="add-prompt-trigger"
sx={selectStyles}
formatOptionLabel={formatOptionLabel}
/>
</FormControl>
);