mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-12 12:35:01 -05:00
Enhance LoRA picker to default to current base model architecture
Co-authored-by: kent <kent@invoke.ai> Enhance LoRA picker to default filter by current base model architecture Co-authored-by: kent <kent@invoke.ai>
This commit is contained in:
committed by
Kent Keirsey
parent
1320a2c5f8
commit
571d286506
@@ -198,6 +198,10 @@ type PickerProps<T extends object> = {
|
||||
* Whether the picker should be searchable. If true, renders a search input.
|
||||
*/
|
||||
searchable?: boolean;
|
||||
/**
|
||||
* Initial state for group toggles. If provided, groups will start with these states instead of all being disabled.
|
||||
*/
|
||||
initialGroupStates?: Record<string, boolean>;
|
||||
};
|
||||
|
||||
export type PickerContextState<T extends object> = {
|
||||
@@ -312,7 +316,10 @@ const flattenOptions = <T extends object>(options: OptionOrGroup<T>[]): T[] => {
|
||||
|
||||
type GroupStatusMap = Record<string, boolean>;
|
||||
|
||||
const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
|
||||
const useTogglableGroups = <T extends object>(
|
||||
options: OptionOrGroup<T>[],
|
||||
initialGroupStates?: Record<string, boolean>
|
||||
) => {
|
||||
const groupsWithOptions = useMemo(() => {
|
||||
const ids: string[] = [];
|
||||
for (const optionOrGroup of options) {
|
||||
@@ -332,14 +339,16 @@ const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
|
||||
const groupStatusMap = $groupStatusMap.get();
|
||||
const newMap: GroupStatusMap = {};
|
||||
for (const id of groupsWithOptions) {
|
||||
if (newMap[id] === undefined) {
|
||||
newMap[id] = false;
|
||||
if (initialGroupStates && initialGroupStates[id] !== undefined) {
|
||||
newMap[id] = initialGroupStates[id];
|
||||
} else if (groupStatusMap[id] !== undefined) {
|
||||
newMap[id] = groupStatusMap[id];
|
||||
} else {
|
||||
newMap[id] = false;
|
||||
}
|
||||
}
|
||||
$groupStatusMap.set(newMap);
|
||||
}, [groupsWithOptions, $groupStatusMap]);
|
||||
}, [groupsWithOptions, $groupStatusMap, initialGroupStates]);
|
||||
|
||||
const toggleGroup = useCallback(
|
||||
(idToToggle: string) => {
|
||||
@@ -511,10 +520,14 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
|
||||
OptionComponent = DefaultOptionComponent,
|
||||
NextToSearchBar,
|
||||
searchable,
|
||||
initialGroupStates,
|
||||
} = props;
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(optionsOrGroups);
|
||||
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(
|
||||
optionsOrGroups,
|
||||
initialGroupStates
|
||||
);
|
||||
const $activeOptionId = useAtom(getFirstOptionId(optionsOrGroups, getOptionId));
|
||||
const $compactView = useAtom(true);
|
||||
const $optionsOrGroups = useAtom(optionsOrGroups);
|
||||
|
||||
@@ -6,6 +6,7 @@ import { useRelatedGroupedModelCombobox } from 'common/hooks/useRelatedGroupedMo
|
||||
import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { selectBase } from 'features/controlLayers/store/paramsSlice';
|
||||
import { ModelPicker } from 'features/parameters/components/ModelPicker';
|
||||
import { API_BASE_MODELS } from 'features/parameters/types/constants';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLoRAModels } from 'services/api/hooks/modelsByType';
|
||||
@@ -58,6 +59,19 @@ const LoRASelect = () => {
|
||||
return t('models.addLora');
|
||||
}, [isLoading, options.length, t]);
|
||||
|
||||
// Calculate initial group states to default to the current base model architecture
|
||||
const initialGroupStates = useMemo(() => {
|
||||
if (!currentBaseModel) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Determine the group ID for the current base model
|
||||
const groupId = API_BASE_MODELS.includes(currentBaseModel) ? 'api' : currentBaseModel;
|
||||
|
||||
// Return a map with only the current base model group enabled
|
||||
return { [groupId]: true };
|
||||
}, [currentBaseModel]);
|
||||
|
||||
return (
|
||||
<FormControl gap={2}>
|
||||
<InformationalPopover feature="lora">
|
||||
@@ -72,6 +86,7 @@ const LoRASelect = () => {
|
||||
placeholder={placeholder}
|
||||
getIsOptionDisabled={getIsDisabled}
|
||||
noOptionsText={t('models.noLoRAsInstalled')}
|
||||
initialGroupStates={initialGroupStates}
|
||||
/>
|
||||
</FormControl>
|
||||
);
|
||||
|
||||
@@ -125,6 +125,7 @@ export const ModelPicker = typedMemo(
|
||||
isInvalid,
|
||||
className,
|
||||
noOptionsText,
|
||||
initialGroupStates,
|
||||
}: {
|
||||
modelConfigs: T[];
|
||||
selectedModelConfig: T | undefined;
|
||||
@@ -137,6 +138,7 @@ export const ModelPicker = typedMemo(
|
||||
isInvalid?: boolean;
|
||||
className?: string;
|
||||
noOptionsText?: string;
|
||||
initialGroupStates?: Record<string, boolean>;
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const options = useMemo<T[] | Group<T>[]>(() => {
|
||||
@@ -244,6 +246,7 @@ export const ModelPicker = typedMemo(
|
||||
NextToSearchBar={<NavigateToModelManagerButton />}
|
||||
getIsOptionDisabled={getIsOptionDisabled}
|
||||
searchable
|
||||
initialGroupStates={initialGroupStates}
|
||||
/>
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
|
||||
Reference in New Issue
Block a user