mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-06 08:24:54 -05:00
386 lines
14 KiB
TypeScript
386 lines
14 KiB
TypeScript
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
|
|
import { createEntityAdapter } from '@reduxjs/toolkit';
|
|
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
|
import queryString from 'query-string';
|
|
import type { operations, paths } from 'services/api/schema';
|
|
import type {
|
|
AnyModelConfig,
|
|
BaseModelType,
|
|
ControlNetModelConfig,
|
|
IPAdapterModelConfig,
|
|
LoRAModelConfig,
|
|
MainModelConfig,
|
|
T2IAdapterModelConfig,
|
|
TextualInversionModelConfig,
|
|
VAEModelConfig,
|
|
} from 'services/api/types';
|
|
|
|
import type { ApiTagDescription, tagTypes } from '..';
|
|
import { api, buildV2Url, LIST_TAG } from '..';
|
|
|
|
export type UpdateModelArg = {
|
|
key: paths['/api/v2/models/i/{key}']['patch']['parameters']['path']['key'];
|
|
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
|
};
|
|
|
|
export type UpdateModelImageArg = {
|
|
key: paths['/api/v2/models/i/{key}/image']['patch']['parameters']['path']['key'];
|
|
image: paths['/api/v2/models/i/{key}/image']['patch']['requestBody']['content']['multipart/form-data'];
|
|
};
|
|
|
|
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
|
type UpdateModelImageResponse =
|
|
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
|
|
|
|
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
|
|
|
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
|
|
|
|
type DeleteModelArg = {
|
|
key: string;
|
|
};
|
|
type DeleteModelResponse = void;
|
|
type DeleteModelImageResponse = void;
|
|
|
|
type ConvertMainModelResponse =
|
|
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
|
|
|
|
type InstallModelArg = {
|
|
source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
|
|
};
|
|
type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];
|
|
|
|
type ListModelInstallsResponse =
|
|
paths['/api/v2/models/install']['get']['responses']['200']['content']['application/json'];
|
|
|
|
type CancelModelInstallResponse =
|
|
paths['/api/v2/models/install/{id}']['delete']['responses']['201']['content']['application/json'];
|
|
|
|
type PruneCompletedModelInstallsResponse =
|
|
paths['/api/v2/models/install']['delete']['responses']['200']['content']['application/json'];
|
|
|
|
export type ScanFolderResponse =
|
|
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
|
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
|
|
|
|
type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query'];
|
|
|
|
const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
|
|
selectId: (entity) => entity.key,
|
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
});
|
|
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
const loraModelsAdapter = createEntityAdapter<LoRAModelConfig, string>({
|
|
selectId: (entity) => entity.key,
|
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
});
|
|
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig, string>({
|
|
selectId: (entity) => entity.key,
|
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
});
|
|
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
const ipAdapterModelsAdapter = createEntityAdapter<IPAdapterModelConfig, string>({
|
|
selectId: (entity) => entity.key,
|
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
});
|
|
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
const t2iAdapterModelsAdapter = createEntityAdapter<T2IAdapterModelConfig, string>({
|
|
selectId: (entity) => entity.key,
|
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
});
|
|
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
const textualInversionModelsAdapter = createEntityAdapter<TextualInversionModelConfig, string>({
|
|
selectId: (entity) => entity.key,
|
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
});
|
|
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
|
|
undefined,
|
|
getSelectorsOptions
|
|
);
|
|
const vaeModelsAdapter = createEntityAdapter<VAEModelConfig, string>({
|
|
selectId: (entity) => entity.key,
|
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
});
|
|
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
|
|
const anyModelConfigAdapter = createEntityAdapter<AnyModelConfig, string>({
|
|
selectId: (entity) => entity.key,
|
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
});
|
|
const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions);
|
|
|
|
const buildProvidesTags =
|
|
<TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
|
|
(result: EntityState<TEntity, string> | undefined) => {
|
|
const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
|
|
if (result) {
|
|
tags.push(
|
|
...result.ids.map((id) => ({
|
|
type: tagType,
|
|
id,
|
|
}))
|
|
);
|
|
}
|
|
|
|
return tags;
|
|
};
|
|
|
|
const buildTransformResponse =
|
|
<T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
|
|
(response: { models: T[] }) => {
|
|
return adapter.setAll(adapter.getInitialState(), response.models);
|
|
};
|
|
|
|
/**
|
|
* Builds an endpoint URL for the models router
|
|
* @example
|
|
* buildModelsUrl('some-path')
|
|
* // '/api/v1/models/some-path'
|
|
*/
|
|
export const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
|
|
|
|
export const modelsApi = api.injectEndpoints({
|
|
endpoints: (build) => ({
|
|
updateModel: build.mutation<UpdateModelResponse, UpdateModelArg>({
|
|
query: ({ key, body }) => {
|
|
return {
|
|
url: buildModelsUrl(`i/${key}`),
|
|
method: 'PATCH',
|
|
body: body,
|
|
};
|
|
},
|
|
invalidatesTags: ['Model'],
|
|
}),
|
|
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
|
|
query: ({ key, image }) => {
|
|
const formData = new FormData();
|
|
formData.append('image', image.image);
|
|
return {
|
|
url: buildModelsUrl(`i/${key}/image`),
|
|
method: 'PATCH',
|
|
body: formData,
|
|
};
|
|
},
|
|
}),
|
|
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
|
query: ({ source }) => {
|
|
return {
|
|
url: buildModelsUrl('install'),
|
|
params: { source },
|
|
method: 'POST',
|
|
};
|
|
},
|
|
invalidatesTags: ['Model', 'ModelInstalls'],
|
|
}),
|
|
deleteModels: build.mutation<DeleteModelResponse, DeleteModelArg>({
|
|
query: ({ key }) => {
|
|
return {
|
|
url: buildModelsUrl(`i/${key}`),
|
|
method: 'DELETE',
|
|
};
|
|
},
|
|
invalidatesTags: ['Model'],
|
|
}),
|
|
deleteModelImage: build.mutation<DeleteModelImageResponse, string>({
|
|
query: (key) => {
|
|
return {
|
|
url: buildModelsUrl(`i/${key}/image`),
|
|
method: 'DELETE',
|
|
};
|
|
},
|
|
}),
|
|
getModelImage: build.query<string, string>({
|
|
query: (key) => buildModelsUrl(`i/${key}/image`)
|
|
}),
|
|
convertModel: build.mutation<ConvertMainModelResponse, string>({
|
|
query: (key) => {
|
|
return {
|
|
url: buildModelsUrl(`convert/${key}`),
|
|
method: 'PUT',
|
|
};
|
|
},
|
|
invalidatesTags: ['ModelConfig'],
|
|
}),
|
|
getModelConfig: build.query<GetModelConfigResponse, string>({
|
|
query: (key) => buildModelsUrl(`i/${key}`),
|
|
providesTags: (result) => {
|
|
const tags: ApiTagDescription[] = ['Model'];
|
|
|
|
if (result) {
|
|
tags.push({ type: 'ModelConfig', id: result.key });
|
|
}
|
|
|
|
return tags;
|
|
},
|
|
}),
|
|
getModelConfigByAttrs: build.query<AnyModelConfig, GetByAttrsArg>({
|
|
query: (arg) => buildModelsUrl(`get_by_attrs?${queryString.stringify(arg)}`),
|
|
providesTags: (result) => {
|
|
const tags: ApiTagDescription[] = ['Model'];
|
|
|
|
if (result) {
|
|
tags.push({ type: 'ModelConfig', id: result.key });
|
|
}
|
|
|
|
return tags;
|
|
},
|
|
serializeQueryArgs: ({ queryArgs }) => `${queryArgs.name}.${queryArgs.base}.${queryArgs.type}`,
|
|
}),
|
|
syncModels: build.mutation<void, void>({
|
|
query: () => {
|
|
return {
|
|
url: buildModelsUrl('sync'),
|
|
method: 'PATCH',
|
|
};
|
|
},
|
|
invalidatesTags: ['Model'],
|
|
}),
|
|
scanFolder: build.query<ScanFolderResponse, ScanFolderArg>({
|
|
query: (arg) => {
|
|
const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
|
|
return {
|
|
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
|
|
};
|
|
},
|
|
}),
|
|
listModelInstalls: build.query<ListModelInstallsResponse, void>({
|
|
query: () => {
|
|
return {
|
|
url: buildModelsUrl('install'),
|
|
};
|
|
},
|
|
providesTags: ['ModelInstalls'],
|
|
}),
|
|
cancelModelInstall: build.mutation<CancelModelInstallResponse, number>({
|
|
query: (id) => {
|
|
return {
|
|
url: buildModelsUrl(`install/${id}`),
|
|
method: 'DELETE',
|
|
};
|
|
},
|
|
invalidatesTags: ['ModelInstalls'],
|
|
}),
|
|
pruneCompletedModelInstalls: build.mutation<PruneCompletedModelInstallsResponse, void>({
|
|
query: () => {
|
|
return {
|
|
url: buildModelsUrl('install'),
|
|
method: 'DELETE',
|
|
};
|
|
},
|
|
invalidatesTags: ['ModelInstalls'],
|
|
}),
|
|
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
|
query: (base_models) => {
|
|
const params: ListModelsArg = {
|
|
model_type: 'main',
|
|
base_models,
|
|
};
|
|
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
|
return buildModelsUrl(`?${query}`);
|
|
},
|
|
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
|
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
queryFulfilled.then(({ data }) => {
|
|
upsertModelConfigs(data, dispatch);
|
|
});
|
|
},
|
|
}),
|
|
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
|
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
|
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
|
|
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
|
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
queryFulfilled.then(({ data }) => {
|
|
upsertModelConfigs(data, dispatch);
|
|
});
|
|
},
|
|
}),
|
|
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
|
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
|
|
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
|
|
transformResponse: buildTransformResponse<ControlNetModelConfig>(controlNetModelsAdapter),
|
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
queryFulfilled.then(({ data }) => {
|
|
upsertModelConfigs(data, dispatch);
|
|
});
|
|
},
|
|
}),
|
|
getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
|
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
|
|
providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
|
|
transformResponse: buildTransformResponse<IPAdapterModelConfig>(ipAdapterModelsAdapter),
|
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
queryFulfilled.then(({ data }) => {
|
|
upsertModelConfigs(data, dispatch);
|
|
});
|
|
},
|
|
}),
|
|
getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
|
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
|
|
providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
|
|
transformResponse: buildTransformResponse<T2IAdapterModelConfig>(t2iAdapterModelsAdapter),
|
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
queryFulfilled.then(({ data }) => {
|
|
upsertModelConfigs(data, dispatch);
|
|
});
|
|
},
|
|
}),
|
|
getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
|
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
|
|
providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
|
|
transformResponse: buildTransformResponse<VAEModelConfig>(vaeModelsAdapter),
|
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
queryFulfilled.then(({ data }) => {
|
|
upsertModelConfigs(data, dispatch);
|
|
});
|
|
},
|
|
}),
|
|
getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
|
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
|
|
providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
|
|
transformResponse: buildTransformResponse<TextualInversionModelConfig>(textualInversionModelsAdapter),
|
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
queryFulfilled.then(({ data }) => {
|
|
upsertModelConfigs(data, dispatch);
|
|
});
|
|
},
|
|
}),
|
|
}),
|
|
});
|
|
|
|
export const {
|
|
useGetModelConfigQuery,
|
|
useGetMainModelsQuery,
|
|
useGetControlNetModelsQuery,
|
|
useGetIPAdapterModelsQuery,
|
|
useGetT2IAdapterModelsQuery,
|
|
useGetLoRAModelsQuery,
|
|
useGetTextualInversionModelsQuery,
|
|
useGetVaeModelsQuery,
|
|
useDeleteModelsMutation,
|
|
useDeleteModelImageMutation,
|
|
useUpdateModelMutation,
|
|
useUpdateModelImageMutation,
|
|
useInstallModelMutation,
|
|
useConvertModelMutation,
|
|
useSyncModelsMutation,
|
|
useLazyScanFolderQuery,
|
|
useListModelInstallsQuery,
|
|
useCancelModelInstallMutation,
|
|
usePruneCompletedModelInstallsMutation,
|
|
} = modelsApi;
|
|
|
|
const upsertModelConfigs = (
|
|
modelConfigs: EntityState<AnyModelConfig, string>,
|
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
dispatch: ThunkDispatch<any, any, UnknownAction>
|
|
) => {
|
|
anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => {
|
|
dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
|
|
const { base, name, type } = modelConfig;
|
|
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
|
|
});
|
|
};
|