diff --git a/invokeai/frontend/web/src/common/components/InvSelect/useGroupedModelInvSelect.ts b/invokeai/frontend/web/src/common/components/InvSelect/useGroupedModelInvSelect.ts index 3cf8445c85..abe14c04f6 100644 --- a/invokeai/frontend/web/src/common/components/InvSelect/useGroupedModelInvSelect.ts +++ b/invokeai/frontend/web/src/common/components/InvSelect/useGroupedModelInvSelect.ts @@ -34,7 +34,7 @@ export const useGroupedModelInvSelect = ( ); const { modelEntities, selectedModel, getIsDisabled, onChange, isLoading } = arg; - const options = useMemo(() => { + const options = useMemo[]>(() => { if (!modelEntities) { return []; } diff --git a/invokeai/frontend/web/src/common/components/InvSelect/useModelInvSelect.ts b/invokeai/frontend/web/src/common/components/InvSelect/useModelInvSelect.ts new file mode 100644 index 0000000000..186d754ef6 --- /dev/null +++ b/invokeai/frontend/web/src/common/components/InvSelect/useModelInvSelect.ts @@ -0,0 +1,91 @@ +import type { EntityState } from '@reduxjs/toolkit'; +import { map } from 'lodash-es'; +import { useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import type { AnyModelConfigEntity } from 'services/api/endpoints/models'; +import { getModelId } from 'services/api/endpoints/models'; + +import type { InvSelectOnChange, InvSelectOption } from './types'; + +type UseModelInvSelectArg = { + modelEntities: EntityState | undefined; + selectedModel?: Pick | null; + onChange: (value: T | null) => void; + getIsDisabled?: (model: T) => boolean; + optionsFilter?: (model: T) => boolean; + isLoading?: boolean; +}; + +type UseModelInvSelectReturn = { + value: InvSelectOption | undefined | null; + options: InvSelectOption[]; + onChange: InvSelectOnChange; + placeholder: string; + noOptionsMessage: () => string; +}; + +export const useModelInvSelect = ( + arg: UseModelInvSelectArg +): UseModelInvSelectReturn => { + const { t } = useTranslation(); + const { + modelEntities, + selectedModel, + getIsDisabled, + onChange, + isLoading, + optionsFilter = () => true, + } = arg; + const options = useMemo(() => { + if (!modelEntities) { + return []; + } + return map(modelEntities.entities) + .filter(optionsFilter) + .map((model) => ({ + label: model.model_name, + value: model.id, + isDisabled: getIsDisabled ? getIsDisabled(model) : false, + })); + }, [optionsFilter, getIsDisabled, modelEntities]); + + const value = useMemo( + () => + options.find((m) => + selectedModel ? m.value === getModelId(selectedModel) : false + ), + [options, selectedModel] + ); + + const _onChange = useCallback( + (v) => { + if (!v) { + onChange(null); + return; + } + const model = modelEntities?.entities[v.value]; + if (!model) { + onChange(null); + return; + } + onChange(model); + }, + [modelEntities?.entities, onChange] + ); + + const placeholder = useMemo(() => { + if (isLoading) { + return t('common.loading'); + } + + if (options.length === 0) { + return t('models.noModelsAvailable'); + } + + return t('models.selectModel'); + }, [isLoading, options, t]); + + const noOptionsMessage = useCallback(() => t('models.noMatchingModels'), [t]); + + return { options, value, onChange: _onChange, placeholder, noOptionsMessage }; +}; diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx index 06c3f980ed..e71716db89 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InvControl } from 'common/components/InvControl/InvControl'; import { InvSelect } from 'common/components/InvSelect/InvSelect'; -import { useGroupedModelInvSelect } from 'common/components/InvSelect/useGroupedModelInvSelect'; +import { useModelInvSelect } from 'common/components/InvSelect/useModelInvSelect'; import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; @@ -15,6 +15,9 @@ const selector = createMemoizedSelector(stateSelector, (state) => ({ model: state.sdxl.refinerModel, })); +const optionsFilter = (model: MainModelConfigEntity) => + model.base_model === 'sdxl-refiner'; + const ParamSDXLRefinerModelSelect = () => { const dispatch = useAppDispatch(); const { model } = useAppSelector(selector); @@ -37,11 +40,12 @@ const ParamSDXLRefinerModelSelect = () => { [dispatch] ); const { options, value, onChange, placeholder, noOptionsMessage } = - useGroupedModelInvSelect({ + useModelInvSelect({ modelEntities: data, onChange: _onChange, selectedModel: model, isLoading, + optionsFilter, }); return (