Merge branch 'main' into onnx-testing

This commit is contained in:
Brandon Rising
2023-07-18 22:56:41 -04:00
361 changed files with 13813 additions and 10110 deletions

View File

@@ -1,7 +1,6 @@
import { OffsetPaginatedResults_ImageDTO_ } from 'services/api/types';
import { api } from '..';
import { ApiFullTagDescription, LIST_TAG, api } from '..';
import { paths } from '../schema';
import { imagesApi } from './images';
type ListBoardImagesArg =
paths['/api/v1/board_images/{board_id}']['get']['parameters']['path'] &
@@ -25,9 +24,26 @@ export const boardImagesApi = api.injectEndpoints({
>({
query: ({ board_id, offset, limit }) => ({
url: `board_images/${board_id}`,
method: 'DELETE',
body: { offset, limit },
method: 'GET',
}),
providesTags: (result, error, arg) => {
// any list of boardimages
const tags: ApiFullTagDescription[] = [
{ type: 'BoardImage', id: `${arg.board_id}_${LIST_TAG}` },
];
if (result) {
// and individual tags for each boardimage
tags.push(
...result.items.map(({ board_id, image_name }) => ({
type: 'BoardImage' as const,
id: `${board_id}_${image_name}`,
}))
);
}
return tags;
},
}),
/**
@@ -41,23 +57,9 @@ export const boardImagesApi = api.injectEndpoints({
body: { board_id, image_name },
}),
invalidatesTags: (result, error, arg) => [
{ type: 'BoardImage' },
{ type: 'Board', id: arg.board_id },
],
async onQueryStarted(
{ image_name, ...patch },
{ dispatch, queryFulfilled }
) {
const patchResult = dispatch(
imagesApi.util.updateQueryData('getImageDTO', image_name, (draft) => {
Object.assign(draft, patch);
})
);
try {
await queryFulfilled;
} catch {
patchResult.undo();
}
},
}),
removeImageFromBoard: build.mutation<void, RemoveImageFromBoardArg>({
@@ -67,23 +69,9 @@ export const boardImagesApi = api.injectEndpoints({
body: { board_id, image_name },
}),
invalidatesTags: (result, error, arg) => [
{ type: 'BoardImage' },
{ type: 'Board', id: arg.board_id },
],
async onQueryStarted(
{ image_name, ...patch },
{ dispatch, queryFulfilled }
) {
const patchResult = dispatch(
imagesApi.util.updateQueryData('getImageDTO', image_name, (draft) => {
Object.assign(draft, { board_id: null });
})
);
try {
await queryFulfilled;
} catch {
patchResult.undo();
}
},
}),
}),
});

View File

@@ -20,7 +20,7 @@ export const boardsApi = api.injectEndpoints({
query: (arg) => ({ url: 'boards/', params: arg }),
providesTags: (result, error, arg) => {
// any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST_TAG }];
const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
if (result) {
// and individual tags for each board
@@ -43,7 +43,7 @@ export const boardsApi = api.injectEndpoints({
}),
providesTags: (result, error, arg) => {
// any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST_TAG }];
const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
if (result) {
// and individual tags for each board
@@ -69,7 +69,7 @@ export const boardsApi = api.injectEndpoints({
method: 'POST',
params: { board_name },
}),
invalidatesTags: [{ id: 'Board', type: LIST_TAG }],
invalidatesTags: [{ type: 'Board', id: LIST_TAG }],
}),
updateBoard: build.mutation<BoardDTO, UpdateBoardArg>({
@@ -87,8 +87,15 @@ export const boardsApi = api.injectEndpoints({
invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }],
}),
deleteBoardAndImages: build.mutation<void, string>({
query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE', params: { include_images: true } }),
invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }, { type: 'Image', id: LIST_TAG }],
query: (board_id) => ({
url: `boards/${board_id}`,
method: 'DELETE',
params: { include_images: true },
}),
invalidatesTags: (result, error, arg) => [
{ type: 'Board', id: arg },
{ type: 'Image', id: LIST_TAG },
],
}),
}),
});
@@ -99,5 +106,5 @@ export const {
useCreateBoardMutation,
useUpdateBoardMutation,
useDeleteBoardMutation,
useDeleteBoardAndImagesMutation
useDeleteBoardAndImagesMutation,
} = boardsApi;

View File

@@ -2,17 +2,31 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
import { cloneDeep } from 'lodash-es';
import {
AnyModelConfig,
BaseModelType,
CheckpointModelConfig,
ControlNetModelConfig,
ConvertModelConfig,
DiffusersModelConfig,
ImportModelConfig,
LoRAModelConfig,
MainModelConfig,
OnnxModelConfig,
MergeModelConfig,
TextualInversionModelConfig,
VaeModelConfig,
} from 'services/api/types';
import queryString from 'query-string';
import { ApiFullTagDescription, LIST_TAG, api } from '..';
import { operations, paths } from '../schema';
export type MainModelConfigEntity = MainModelConfig & { id: string };
export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
id: string;
};
export type MainModelConfigEntity =
| DiffusersModelConfigEntity
| CheckpointModelConfigEntity;
export type OnnxModelConfigEntity = OnnxModelConfig & { id: string };
@@ -36,6 +50,61 @@ type AnyModelConfigEntity =
| TextualInversionModelConfigEntity
| VaeModelConfigEntity;
type UpdateMainModelArg = {
base_model: BaseModelType;
model_name: string;
body: MainModelConfig;
};
type UpdateMainModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
type DeleteMainModelArg = {
base_model: BaseModelType;
model_name: string;
};
type DeleteMainModelResponse = void;
type ConvertMainModelArg = {
base_model: BaseModelType;
model_name: string;
params: ConvertModelConfig;
};
type ConvertMainModelResponse =
paths['/api/v1/models/convert/{base_model}/{model_type}/{model_name}']['put']['responses']['200']['content']['application/json'];
type MergeMainModelArg = {
base_model: BaseModelType;
body: MergeModelConfig;
};
type MergeMainModelResponse =
paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json'];
type ImportMainModelArg = {
body: ImportModelConfig;
};
type ImportMainModelResponse =
paths['/api/v1/models/import']['post']['responses']['201']['content']['application/json'];
type AddMainModelArg = {
body: MainModelConfig;
};
type AddMainModelResponse =
paths['/api/v1/models/add']['post']['responses']['201']['content']['application/json'];
export type SearchFolderResponse =
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json'];
type CheckpointConfigsResponse =
paths['/api/v1/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
@@ -116,7 +185,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'main' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'MainModel', type: LIST_TAG },
{ type: 'MainModel', id: LIST_TAG },
];
if (result) {
@@ -144,11 +213,82 @@ export const modelsApi = api.injectEndpoints({
);
},
}),
updateMainModels: build.mutation<
UpdateMainModelResponse,
UpdateMainModelArg
>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/main/${model_name}`,
method: 'PATCH',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
importMainModels: build.mutation<
ImportMainModelResponse,
ImportMainModelArg
>({
query: ({ body }) => {
return {
url: `models/import`,
method: 'POST',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
query: ({ body }) => {
return {
url: `models/add`,
method: 'POST',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
deleteMainModels: build.mutation<
DeleteMainModelResponse,
DeleteMainModelArg
>({
query: ({ base_model, model_name }) => {
return {
url: `models/${base_model}/main/${model_name}`,
method: 'DELETE',
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
convertMainModels: build.mutation<
ConvertMainModelResponse,
ConvertMainModelArg
>({
query: ({ base_model, model_name, params }) => {
return {
url: `models/convert/${base_model}/main/${model_name}`,
method: 'PUT',
params: params,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
query: ({ base_model, body }) => {
return {
url: `models/merge/${base_model}`,
method: 'PUT',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'LoRAModel', type: LIST_TAG },
{ type: 'LoRAModel', id: LIST_TAG },
];
if (result) {
@@ -183,7 +323,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'ControlNetModel', type: LIST_TAG },
{ type: 'ControlNetModel', id: LIST_TAG },
];
if (result) {
@@ -215,7 +355,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'VaeModel', type: LIST_TAG },
{ type: 'VaeModel', id: LIST_TAG },
];
if (result) {
@@ -250,7 +390,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'TextualInversionModel', type: LIST_TAG },
{ type: 'TextualInversionModel', id: LIST_TAG },
];
if (result) {
@@ -278,6 +418,36 @@ export const modelsApi = api.injectEndpoints({
);
},
}),
getModelsInFolder: build.query<SearchFolderResponse, SearchFolderArg>({
query: (arg) => {
const folderQueryStr = queryString.stringify(arg, {});
return {
url: `/models/search?${folderQueryStr}`,
};
},
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ type: 'ScannedModels', id: LIST_TAG },
];
if (result) {
tags.push(
...result.map((id) => ({
type: 'ScannedModels' as const,
id,
}))
);
}
return tags;
},
}),
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
query: () => {
return {
url: `/models/ckpt_confs`,
};
},
}),
}),
});
@@ -288,4 +458,12 @@ export const {
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
useUpdateMainModelsMutation,
useDeleteMainModelsMutation,
useImportMainModelsMutation,
useAddMainModelsMutation,
useConvertMainModelsMutation,
useMergeMainModelsMutation,
useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery,
} = modelsApi;

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,9 @@
import { createAppAsyncThunk } from 'app/store/storeUtils';
import { selectImagesAll } from 'features/gallery/store/gallerySlice';
import { selectFilteredImages } from 'features/gallery/store/gallerySelectors';
import {
ASSETS_CATEGORIES,
IMAGE_CATEGORIES,
} from 'features/gallery/store/gallerySlice';
import { size } from 'lodash-es';
import queryString from 'query-string';
import { $client } from 'services/api/client';
@@ -287,15 +291,12 @@ export const receivedPageOfImages = createAppAsyncThunk<
const { get } = $client.get();
const state = getState();
const { categories, selectedBoardId } = state.gallery;
const images = selectImagesAll(state).filter((i) => {
const isInCategory = categories.includes(i.image_category);
const isInSelectedBoard = selectedBoardId
? i.board_id === selectedBoardId
: true;
return isInCategory && isInSelectedBoard;
});
const images = selectFilteredImages(state);
const categories =
state.gallery.galleryView === 'images'
? IMAGE_CATEGORIES
: ASSETS_CATEGORIES;
let query: ListImagesArg = {};

View File

@@ -28,11 +28,14 @@ export type OffsetPaginatedResults_ImageDTO_ =
// Models
export type ModelType = components['schemas']['ModelType'];
export type SubModelType = components['schemas']['SubModelType'];
export type BaseModelType = components['schemas']['BaseModelType'];
export type MainModelField = components['schemas']['MainModelField'];
export type OnnxModelField = components['schemas']['OnnxModelField'];
export type VAEModelField = components['schemas']['VAEModelField'];
export type LoRAModelField = components['schemas']['LoRAModelField'];
export type ControlNetModelField =
components['schemas']['ControlNetModelField'];
export type ModelsList = components['schemas']['ModelsList'];
export type ControlField = components['schemas']['ControlField'];
@@ -43,18 +46,28 @@ export type ControlNetModelConfig =
components['schemas']['ControlNetModelConfig'];
export type TextualInversionModelConfig =
components['schemas']['TextualInversionModelConfig'];
export type MainModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
export type DiffusersModelConfig =
| components['schemas']['StableDiffusion1ModelDiffusersConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig']
| components['schemas']['StableDiffusionXLModelDiffusersConfig'];
export type CheckpointModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
export type OnnxModelConfig = components['schemas']['ONNXStableDiffusion1ModelConfig']
| components['schemas']['StableDiffusionXLModelCheckpointConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| LoRAModelConfig
| VaeModelConfig
| ControlNetModelConfig
| TextualInversionModelConfig
| MainModelConfig;
| MainModelConfig
| OnnxModelConfig;
export type MergeModelConfig = components['schemas']['Body_merge_models'];
export type ConvertModelConfig = components['schemas']['Body_convert_model'];
export type ImportModelConfig = components['schemas']['Body_import_model'];
// Graphs
export type Graph = components['schemas']['Graph'];
@@ -81,6 +94,9 @@ export type InpaintInvocation = TypeReq<
export type ImageResizeInvocation = TypeReq<
components['schemas']['ImageResizeInvocation']
>;
export type ImageScaleInvocation = TypeReq<
components['schemas']['ImageScaleInvocation']
>;
export type RandomIntInvocation = TypeReq<
components['schemas']['RandomIntInvocation']
>;
@@ -118,6 +134,12 @@ export type LoraLoaderInvocation = TypeReq<
export type MetadataAccumulatorInvocation = TypeReq<
components['schemas']['MetadataAccumulatorInvocation']
>;
export type ESRGANInvocation = TypeReq<
components['schemas']['ESRGANInvocation']
>;
export type DivideInvocation = TypeReq<
components['schemas']['DivideInvocation']
>;
// ControlNet Nodes
export type ControlNetInvocation = TypeReq<