import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import queryString from 'query-string';
import type { operations, paths } from 'services/api/schema';
import type {
  AnyModelConfig,
  BaseModelType,
  ControlNetModelConfig,
  IPAdapterModelConfig,
  LoRAModelConfig,
  MainModelConfig,
  T2IAdapterModelConfig,
  TextualInversionModelConfig,
  VAEModelConfig,
} from 'services/api/types';

import type { ApiTagDescription, tagTypes } from '..';
import { api, buildV2Url, LIST_TAG } from '..';

/** Argument for the `updateModel` mutation (PATCH `models/i/{key}`). */
export type UpdateModelArg = {
  key: paths['/api/v2/models/i/{key}']['patch']['parameters']['path']['key'];
  body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
};

/** Argument for the `updateModelImage` mutation (PATCH `models/i/{key}/image`). */
export type UpdateModelImageArg = {
  key: paths['/api/v2/models/i/{key}/image']['patch']['parameters']['path']['key'];
  /**
   * Image data sent as the `image` field of a multipart form body.
   * `FormData.append` requires a `Blob` (a `File` is also accepted, since it extends `Blob`).
   */
  image: Blob;
};

type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateModelImageResponse =
  paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
type GetModelImageResponse =
  paths['/api/v2/models/i/{key}/image']['get']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];

/** Query parameters of the model list endpoint (`buildModelsUrl()`), with `undefined` removed. */
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;

type DeleteModelArg = {
  key: string;
};
type DeleteModelResponse = void;

type ConvertMainModelResponse =
  paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];

type InstallModelArg = {
  source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
};
type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];

type ListModelInstallsResponse =
  paths['/api/v2/models/install']['get']['responses']['200']['content']['application/json'];

type CancelModelInstallResponse =
  paths['/api/v2/models/install/{id}']['delete']['responses']['201']['content']['application/json'];

type PruneCompletedModelInstallsResponse =
  paths['/api/v2/models/install']['delete']['responses']['200']['content']['application/json'];

export type ScanFolderResponse =
  paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];

type GetByAttrsArg = operations['get_model_records_by_attrs']['parameters']['query'];

/**
 * Creates an entity adapter for one kind of model config.
 * Entities are identified by their `key` and kept sorted by `name` (locale-aware comparison).
 */
const createModelConfigAdapter = <T extends AnyModelConfig>() =>
  createEntityAdapter<T, string>({
    selectId: (entity) => entity.key,
    sortComparer: (a, b) => a.name.localeCompare(b.name),
  });

const mainModelsAdapter = createModelConfigAdapter<MainModelConfig>();
export const mainModelsAdapterSelectors = mainModelsAdapter.getSelectors(undefined, getSelectorsOptions);

const loraModelsAdapter = createModelConfigAdapter<LoRAModelConfig>();
export const loraModelsAdapterSelectors = loraModelsAdapter.getSelectors(undefined, getSelectorsOptions);

const controlNetModelsAdapter = createModelConfigAdapter<ControlNetModelConfig>();
export const controlNetModelsAdapterSelectors = controlNetModelsAdapter.getSelectors(undefined, getSelectorsOptions);

const ipAdapterModelsAdapter = createModelConfigAdapter<IPAdapterModelConfig>();
export const ipAdapterModelsAdapterSelectors = ipAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);

const t2iAdapterModelsAdapter = createModelConfigAdapter<T2IAdapterModelConfig>();
export const t2iAdapterModelsAdapterSelectors = t2iAdapterModelsAdapter.getSelectors(undefined, getSelectorsOptions);

const textualInversionModelsAdapter = createModelConfigAdapter<TextualInversionModelConfig>();
export const textualInversionModelsAdapterSelectors = textualInversionModelsAdapter.getSelectors(
  undefined,
  getSelectorsOptions
);

const vaeModelsAdapter = createModelConfigAdapter<VAEModelConfig>();
export const vaeModelsAdapterSelectors = vaeModelsAdapter.getSelectors(undefined, getSelectorsOptions);

/** Adapter used only to read back any kind of model config (see `upsertModelConfigs`). */
const anyModelConfigAdapter = createModelConfigAdapter<AnyModelConfig>();
const anyModelConfigAdapterSelectors = anyModelConfigAdapter.getSelectors(undefined, getSelectorsOptions);

/**
 * Builds a `providesTags` callback for a model list endpoint.
 *
 * The tags are always the `LIST_TAG` tag of `tagType` plus the generic `'Model'` tag. On success, one
 * `{ type: tagType, id }` tag is added per entity id, so per-model invalidation can target the list.
 *
 * @param tagType - tag type that identifies this list's entities
 */
const buildProvidesTags =
  <TEntity extends AnyModelConfig>(tagType: (typeof tagTypes)[number]) =>
  (result: EntityState<TEntity, string> | undefined) => {
    const tags: ApiTagDescription[] = [{ type: tagType, id: LIST_TAG }, 'Model'];
    if (result) {
      tags.push(
        ...result.ids.map((id) => ({
          type: tagType,
          id,
        }))
      );
    }
    return tags;
  };

/**
 * Builds a `transformResponse` that turns a `{ models: T[] }` payload into normalized entity state.
 * The state starts from a fresh initial state, so models absent from the response are dropped.
 *
 * @param adapter - entity adapter used to normalize (and sort) the models
 */
const buildTransformResponse =
  <T extends AnyModelConfig>(adapter: EntityAdapter<T, string>) =>
  (response: { models: T[] }) => {
    return adapter.setAll(adapter.getInitialState(), response.models);
  };

/**
 * Builds an endpoint URL for the models router
 * @example
 * buildModelsUrl('some-path')
 * // '/api/v2/models/some-path'
 */
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);

export const modelsApi = api.injectEndpoints({
  endpoints: (build) => ({
    updateModel: build.mutation<UpdateModelResponse, UpdateModelArg>({
      query: ({ key, body }) => {
        return {
          url: buildModelsUrl(`i/${key}`),
          method: 'PATCH',
          body: body,
        };
      },
      invalidatesTags: ['Model'],
    }),
    getModelImage: build.query<GetModelImageResponse, string>({
      query: (key) => buildModelsUrl(`i/${key}/image`),
    }),
    updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
      query: ({ key, image }) => {
        // The image is uploaded as multipart form data under the `image` field.
        const formData = new FormData();
        formData.append('image', image);
        return {
          url: buildModelsUrl(`i/${key}/image`),
          method: 'PATCH',
          body: formData,
        };
      },
      invalidatesTags: ['Model'],
    }),
    installModel: build.mutation<InstallModelResponse, InstallModelArg>({
      query: ({ source }) => {
        return {
          url: buildModelsUrl('install'),
          params: { source },
          method: 'POST',
        };
      },
      invalidatesTags: ['Model', 'ModelInstalls'],
    }),
    deleteModels: build.mutation<DeleteModelResponse, DeleteModelArg>({
      query: ({ key }) => {
        return {
          url: buildModelsUrl(`i/${key}`),
          method: 'DELETE',
        };
      },
      invalidatesTags: ['Model'],
    }),
    convertModel: build.mutation<ConvertMainModelResponse, string>({
      query: (key) => {
        return {
          url: buildModelsUrl(`convert/${key}`),
          method: 'PUT',
        };
      },
      invalidatesTags: ['ModelConfig'],
    }),
    getModelConfig: build.query<GetModelConfigResponse, string>({
      query: (key) => buildModelsUrl(`i/${key}`),
      providesTags: (result) => {
        const tags: ApiTagDescription[] = ['Model'];
        if (result) {
          tags.push({ type: 'ModelConfig', id: result.key });
        }
        return tags;
      },
    }),
    getModelConfigByAttrs: build.query<AnyModelConfig, GetByAttrsArg>({
      query: (arg) => buildModelsUrl(`get_by_attrs?${queryString.stringify(arg)}`),
      providesTags: (result) => {
        const tags: ApiTagDescription[] = ['Model'];
        if (result) {
          tags.push({ type: 'ModelConfig', id: result.key });
        }
        return tags;
      },
      // The cache key is built from name/base/type only; `upsertModelConfigs` relies on this to seed entries.
      serializeQueryArgs: ({ queryArgs }) => `${queryArgs.name}.${queryArgs.base}.${queryArgs.type}`,
    }),
    syncModels: build.mutation<void, void>({
      query: () => {
        return {
          url: buildModelsUrl('sync'),
          method: 'PATCH',
        };
      },
      invalidatesTags: ['Model'],
    }),
    scanFolder: build.query<ScanFolderResponse, ScanFolderArg>({
      query: (arg) => {
        const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
        return {
          url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
        };
      },
    }),
    listModelInstalls: build.query<ListModelInstallsResponse, void>({
      query: () => {
        return {
          url: buildModelsUrl('install'),
        };
      },
      providesTags: ['ModelInstalls'],
    }),
    cancelModelInstall: build.mutation<CancelModelInstallResponse, number>({
      query: (id) => {
        return {
          url: buildModelsUrl(`install/${id}`),
          method: 'DELETE',
        };
      },
      invalidatesTags: ['ModelInstalls'],
    }),
    pruneCompletedModelInstalls: build.mutation<PruneCompletedModelInstallsResponse, void>({
      query: () => {
        return {
          url: buildModelsUrl('install'),
          method: 'DELETE',
        };
      },
      invalidatesTags: ['ModelInstalls'],
    }),
    getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
      query: (base_models) => {
        const params: ListModelsArg = {
          model_type: 'main',
          base_models,
        };
        const query = queryString.stringify(params, { arrayFormat: 'none' });
        return buildModelsUrl(`?${query}`);
      },
      providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
      transformResponse: buildTransformResponse(mainModelsAdapter),
      onQueryStarted: (_, { dispatch, queryFulfilled }) => upsertModelConfigsOnFulfilled(queryFulfilled, dispatch),
    }),
    getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
      query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
      providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
      transformResponse: buildTransformResponse(loraModelsAdapter),
      onQueryStarted: (_, { dispatch, queryFulfilled }) => upsertModelConfigsOnFulfilled(queryFulfilled, dispatch),
    }),
    getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
      query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
      providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
      transformResponse: buildTransformResponse(controlNetModelsAdapter),
      onQueryStarted: (_, { dispatch, queryFulfilled }) => upsertModelConfigsOnFulfilled(queryFulfilled, dispatch),
    }),
    getIPAdapterModels: build.query<EntityState<IPAdapterModelConfig, string>, void>({
      query: () => ({ url: buildModelsUrl(), params: { model_type: 'ip_adapter' } }),
      providesTags: buildProvidesTags<IPAdapterModelConfig>('IPAdapterModel'),
      transformResponse: buildTransformResponse(ipAdapterModelsAdapter),
      onQueryStarted: (_, { dispatch, queryFulfilled }) => upsertModelConfigsOnFulfilled(queryFulfilled, dispatch),
    }),
    getT2IAdapterModels: build.query<EntityState<T2IAdapterModelConfig, string>, void>({
      query: () => ({ url: buildModelsUrl(), params: { model_type: 't2i_adapter' } }),
      providesTags: buildProvidesTags<T2IAdapterModelConfig>('T2IAdapterModel'),
      transformResponse: buildTransformResponse(t2iAdapterModelsAdapter),
      onQueryStarted: (_, { dispatch, queryFulfilled }) => upsertModelConfigsOnFulfilled(queryFulfilled, dispatch),
    }),
    getVaeModels: build.query<EntityState<VAEModelConfig, string>, void>({
      query: () => ({ url: buildModelsUrl(), params: { model_type: 'vae' } }),
      providesTags: buildProvidesTags<VAEModelConfig>('VaeModel'),
      transformResponse: buildTransformResponse(vaeModelsAdapter),
      onQueryStarted: (_, { dispatch, queryFulfilled }) => upsertModelConfigsOnFulfilled(queryFulfilled, dispatch),
    }),
    getTextualInversionModels: build.query<EntityState<TextualInversionModelConfig, string>, void>({
      query: () => ({ url: buildModelsUrl(), params: { model_type: 'embedding' } }),
      providesTags: buildProvidesTags<TextualInversionModelConfig>('TextualInversionModel'),
      transformResponse: buildTransformResponse(textualInversionModelsAdapter),
      onQueryStarted: (_, { dispatch, queryFulfilled }) => upsertModelConfigsOnFulfilled(queryFulfilled, dispatch),
    }),
  }),
});

export const {
  useGetModelConfigQuery,
  useGetMainModelsQuery,
  useGetControlNetModelsQuery,
  useGetIPAdapterModelsQuery,
  useGetT2IAdapterModelsQuery,
  useGetLoRAModelsQuery,
  useGetTextualInversionModelsQuery,
  useGetVaeModelsQuery,
  useDeleteModelsMutation,
  useUpdateModelMutation,
  useGetModelImageQuery,
  useUpdateModelImageMutation,
  useInstallModelMutation,
  useConvertModelMutation,
  useSyncModelsMutation,
  useLazyScanFolderQuery,
  useListModelInstallsQuery,
  useCancelModelInstallMutation,
  usePruneCompletedModelInstallsMutation,
} = modelsApi;

/**
 * Alias of `useUpdateModelMutation`, kept so existing imports of this name still work.
 * The endpoint is named `updateModel`, so RTK Query only generates `useUpdateModelMutation`.
 * @deprecated Use `useUpdateModelMutation`.
 */
export const useUpdateModelsMutation = useUpdateModelMutation;

/**
 * Seeds the single-model caches from a list result.
 *
 * For every config in `modelConfigs`, this upserts the `getModelConfig` entry (keyed by `key`) and the
 * `getModelConfigByAttrs` entry (keyed by `{ base, name, type }`). Components that look up one model then
 * hit the cache instead of issuing another request.
 */
const upsertModelConfigs = (
  modelConfigs: EntityState<AnyModelConfig, string>,
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  dispatch: ThunkDispatch<any, any, UnknownAction>
) => {
  anyModelConfigAdapterSelectors.selectAll(modelConfigs).forEach((modelConfig) => {
    dispatch(modelsApi.util.upsertQueryData('getModelConfig', modelConfig.key, modelConfig));
    const { base, name, type } = modelConfig;
    dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, modelConfig));
  });
};

/**
 * `onQueryStarted` body shared by the model list endpoints: waits for the list query and, on success,
 * passes the normalized result to `upsertModelConfigs`.
 *
 * RTK Query rejects `queryFulfilled` when the request fails, and that failure already surfaces through
 * the query's own error state. The rejection is caught here only so it does not become an unhandled
 * promise rejection. Errors thrown by `upsertModelConfigs` itself are not swallowed.
 */
const upsertModelConfigsOnFulfilled = async (
  queryFulfilled: Promise<{ data: EntityState<AnyModelConfig, string> }>,
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  dispatch: ThunkDispatch<any, any, UnknownAction>
): Promise<void> => {
  let result: { data: EntityState<AnyModelConfig, string> };
  try {
    result = await queryFulfilled;
  } catch {
    // The list request failed; there is nothing to upsert.
    return;
  }
  upsertModelConfigs(result.data, dispatch);
};