fix(ui): fix refiner missing from model manager

Rolled back the earlier split of the refiner model query.

Now, when you use `useGetMainModelsQuery()`, you must provide it an array of base model types.

They are provided as constants for simplicity:
- ALL_BASE_MODELS
- NON_REFINER_BASE_MODELS
- REFINER_BASE_MODELS

Opted to just use args for the hook instead of wrapping the hook in another hook, we can tidy this up later if desired.
This commit is contained in:
psychedelicious
2023-07-26 11:04:02 +10:00
parent 6fa244a343
commit cbcd416b70
19 changed files with 72 additions and 75 deletions

View File

@@ -14,6 +14,7 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
@@ -27,7 +28,9 @@ const ModelInputFieldComponent = (
const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { data: mainModels, isLoading } = useGetMainModelsQuery();
const { data: mainModels, isLoading } = useGetMainModelsQuery(
NON_REFINER_BASE_MODELS
);
const data = useMemo(() => {
if (!mainModels) {

View File

@@ -13,7 +13,8 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
const RefinerModelInputFieldComponent = (
@@ -27,7 +28,8 @@ const RefinerModelInputFieldComponent = (
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery();
const { data: refinerModels, isLoading } =
useGetMainModelsQuery(REFINER_BASE_MODELS);
const data = useMemo(() => {
if (!refinerModels) {

View File

@@ -14,6 +14,7 @@ import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
@@ -29,8 +30,10 @@ const ParamMainModelSelect = () => {
const { model } = useAppSelector(selector);
const { data: mainModels, isLoading } = useGetMainModelsQuery();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { data: mainModels, isLoading } = useGetMainModelsQuery(
NON_REFINER_BASE_MODELS
);
const data = useMemo(() => {
if (!mainModels) {

View File

@@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],

View File

@@ -4,10 +4,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],

View File

@@ -11,7 +11,8 @@ import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
const selector = createSelector(
stateSelector,
@@ -24,7 +25,8 @@ const ParamSDXLRefinerModelSelect = () => {
const { model } = useAppSelector(selector);
const { data: refinerModels, isLoading } = useGetSDXLRefinerModelsQuery();
const { data: refinerModels, isLoading } =
useGetMainModelsQuery(REFINER_BASE_MODELS);
const data = useMemo(() => {
if (!refinerModels) {

View File

@@ -7,11 +7,11 @@ import {
SCHEDULER_LABEL_MAP,
SchedulerParam,
} from 'features/parameters/types/parameterSchemas';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
stateSelector,

View File

@@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],

View File

@@ -4,10 +4,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],

View File

@@ -1,9 +1,9 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { useIsRefinerAvailable } from 'features/sdxl/hooks/useIsRefinerAvailable';
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
import { ChangeEvent } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
export default function ParamUseSDXLRefiner() {
const shouldUseSDXLRefiner = useAppSelector(

View File

@@ -1,11 +0,0 @@
import { useGetSDXLRefinerModelsQuery } from 'services/api/endpoints/models';
export const useIsRefinerAvailable = () => {
const { isRefinerAvailable } = useGetSDXLRefinerModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
isRefinerAvailable: data ? data.ids.length > 0 : false,
}),
});
return isRefinerAvailable;
};

View File

@@ -16,6 +16,7 @@ import {
useImportMainModelsMutation,
} from 'services/api/endpoints/models';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import { ALL_BASE_MODELS } from 'services/api/constants';
export default function FoundModelsList() {
const searchFolder = useAppSelector(
@@ -24,7 +25,7 @@ export default function FoundModelsList() {
const [nameFilter, setNameFilter] = useState<string>('');
// Get paths of models that are already installed
const { data: installedModels } = useGetMainModelsQuery();
const { data: installedModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
// Get all model paths from a given directory
const { foundModels, alreadyInstalled, filteredModels } =

View File

@@ -1,5 +1,4 @@
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
@@ -8,9 +7,11 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAISlider from 'common/components/IAISlider';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { pickBy } from 'lodash-es';
import { useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import {
useGetMainModelsQuery,
useMergeMainModelsMutation,
@@ -32,7 +33,7 @@ export default function MergeModelsPanel() {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { data } = useGetMainModelsQuery();
const { data } = useGetMainModelsQuery(ALL_BASE_MODELS);
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();

View File

@@ -8,10 +8,11 @@ import {
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList';
import { ALL_BASE_MODELS } from 'services/api/constants';
export default function ModelManagerPanel() {
const [selectedModelId, setSelectedModelId] = useState<string>();
const { model } = useGetMainModelsQuery(undefined, {
const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
}),

View File

@@ -11,6 +11,7 @@ import {
useGetMainModelsQuery,
} from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem';
import { ALL_BASE_MODELS } from 'services/api/constants';
type ModelListProps = {
selectedModelId: string | undefined;
@@ -26,13 +27,13 @@ const ModelList = (props: ModelListProps) => {
const [modelFormatFilter, setModelFormatFilter] =
useState<ModelFormat>('images');
const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
}),
});
const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
}),