mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 11:44:55 -05:00
Added related model support
This commit is contained in:
committed by
psychedelicious
parent
c10865c7ef
commit
49ae66d94a
@@ -1,5 +1,5 @@
|
||||
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, Icon } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { GroupBase } from 'chakra-react-select';
|
||||
@@ -10,12 +10,16 @@ import type { PromptTriggerSelectProps } from 'features/prompt/types';
|
||||
import { t } from 'i18next';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiLinkSimple } from 'react-icons/pi';
|
||||
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
|
||||
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||
import { useEmbeddingModels, useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
|
||||
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
|
||||
|
||||
type RelatedEmbedding = ComboboxOption & { starred?: boolean };
|
||||
|
||||
export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -27,6 +31,27 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
||||
const [loraModels, { isLoading: isLoadingLoRAs }] = useLoRAModels();
|
||||
const [tiModels, { isLoading: isLoadingTIs }] = useEmbeddingModels();
|
||||
|
||||
// Get related model keys for current selected models
|
||||
const selectedModelKeys = useMemo(() => {
|
||||
const keys: string[] = [];
|
||||
if (mainModel) {
|
||||
keys.push(mainModel.key);
|
||||
}
|
||||
for (const { model } of addedLoRAs) {
|
||||
keys.push(model.key);
|
||||
}
|
||||
return keys;
|
||||
}, [mainModel, addedLoRAs]);
|
||||
|
||||
const { relatedModelKeys } = useGetRelatedModelIdsBatchQuery(selectedModelKeys, {
|
||||
selectFromResult: ({ data }) => {
|
||||
if (!data) {
|
||||
return { relatedModelKeys: [] };
|
||||
}
|
||||
return { relatedModelKeys: data };
|
||||
},
|
||||
});
|
||||
|
||||
const _onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!v) {
|
||||
@@ -62,9 +87,25 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
||||
}
|
||||
|
||||
if (tiModels) {
|
||||
const embeddingOptions = tiModels
|
||||
// Create embedding options with starred property for related models
|
||||
const embeddingOptions: RelatedEmbedding[] = tiModels
|
||||
.filter((ti) => ti.base === mainModelConfig?.base)
|
||||
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
|
||||
.map((model) => ({
|
||||
label: model.name,
|
||||
value: `<${model.name}>`,
|
||||
starred: relatedModelKeys.includes(model.key),
|
||||
}));
|
||||
|
||||
// Sort so related embeddings come first
|
||||
embeddingOptions.sort((a, b) => {
|
||||
if (a.starred && !b.starred) {
|
||||
return -1;
|
||||
}
|
||||
if (!a.starred && b.starred) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
});
|
||||
|
||||
if (embeddingOptions.length > 0) {
|
||||
_options.push({
|
||||
@@ -85,7 +126,20 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
||||
}
|
||||
|
||||
return _options;
|
||||
}, [tiModels, loraModels, mainModelConfig, t, addedLoRAs]);
|
||||
}, [tiModels, loraModels, mainModelConfig, t, addedLoRAs, relatedModelKeys]);
|
||||
|
||||
const formatOptionLabel = useCallback((option: ComboboxOption) => {
|
||||
const embeddingOption = option as RelatedEmbedding;
|
||||
if (embeddingOption.starred) {
|
||||
return (
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: '8px' }}>
|
||||
<Icon as={PiLinkSimple} color="invokeYellow.500" boxSize={3} />
|
||||
{option.label}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return option.label;
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
@@ -104,6 +158,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
|
||||
onMenuClose={onClose}
|
||||
data-testid="add-prompt-trigger"
|
||||
sx={selectStyles}
|
||||
formatOptionLabel={formatOptionLabel}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user