mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-16 12:05:25 -05:00
fix(ui): fix refiner missing from model manager
Rolled back the earlier split of the refiner model query. Now, when you use `useGetMainModelsQuery()`, you must provide it an array of base model types. They are provided as constants for simplicity: - ALL_BASE_MODELS - NON_REFINER_BASE_MODELS - REFINER_BASE_MODELS Opted to just use args for the hook instead of wrapping the hook in another hook, we can tidy this up later if desired.
This commit is contained in:
16
invokeai/frontend/web/src/services/api/constants.ts
Normal file
16
invokeai/frontend/web/src/services/api/constants.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import { BaseModelType } from './types';
|
||||
|
||||
export const ALL_BASE_MODELS: BaseModelType[] = [
|
||||
'sd-1',
|
||||
'sd-2',
|
||||
'sdxl',
|
||||
'sdxl-refiner',
|
||||
];
|
||||
|
||||
export const NON_REFINER_BASE_MODELS: BaseModelType[] = [
|
||||
'sd-1',
|
||||
'sd-2',
|
||||
'sdxl',
|
||||
];
|
||||
|
||||
export const REFINER_BASE_MODELS: BaseModelType[] = ['sdxl-refiner'];
|
||||
@@ -107,9 +107,6 @@ type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
||||
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
const sdxlRefinerModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
const loraModelsAdapter = createEntityAdapter<LoRAModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||
});
|
||||
@@ -147,11 +144,14 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
|
||||
|
||||
export const modelsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
|
||||
query: () => {
|
||||
getMainModels: build.query<
|
||||
EntityState<MainModelConfigEntity>,
|
||||
BaseModelType[]
|
||||
>({
|
||||
query: (base_models) => {
|
||||
const params = {
|
||||
model_type: 'main',
|
||||
base_models: ['sd-1', 'sd-2', 'sdxl'],
|
||||
base_models,
|
||||
};
|
||||
|
||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||
@@ -187,43 +187,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
);
|
||||
},
|
||||
}),
|
||||
getSDXLRefinerModels: build.query<EntityState<MainModelConfigEntity>, void>(
|
||||
{
|
||||
query: () => ({
|
||||
url: 'models/',
|
||||
params: { model_type: 'main', base_models: ['sdxl-refiner'] },
|
||||
}),
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
];
|
||||
|
||||
if (result) {
|
||||
tags.push(
|
||||
...result.ids.map((id) => ({
|
||||
type: 'SDXLRefinerModel' as const,
|
||||
id,
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
return tags;
|
||||
},
|
||||
transformResponse: (
|
||||
response: { models: MainModelConfig[] },
|
||||
meta,
|
||||
arg
|
||||
) => {
|
||||
const entities = createModelEntities<MainModelConfigEntity>(
|
||||
response.models
|
||||
);
|
||||
return sdxlRefinerModelsAdapter.setAll(
|
||||
sdxlRefinerModelsAdapter.getInitialState(),
|
||||
entities
|
||||
);
|
||||
},
|
||||
}
|
||||
),
|
||||
updateMainModels: build.mutation<
|
||||
UpdateMainModelResponse,
|
||||
UpdateMainModelArg
|
||||
@@ -494,7 +457,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
|
||||
export const {
|
||||
useGetMainModelsQuery,
|
||||
useGetSDXLRefinerModelsQuery,
|
||||
useGetControlNetModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
useGetTextualInversionModelsQuery,
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
export const useIsRefinerAvailable = () => {
|
||||
const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
isRefinerAvailable: data ? data.ids.length > 0 : false,
|
||||
}),
|
||||
});
|
||||
|
||||
return isRefinerAvailable;
|
||||
};
|
||||
Reference in New Issue
Block a user