feat: workflow saving and loading

This commit is contained in:
psychedelicious
2023-08-24 21:42:32 +10:00
parent 7f6fdf5d39
commit 7d1942e9f0
51 changed files with 1175 additions and 320 deletions

View File

@@ -29,13 +29,15 @@ export const $projectId = atom<string | undefined>();
* @example
* const { get, post, del } = $client.get();
*/
export const $client = computed([$authToken, $baseUrl, $projectId], (authToken, baseUrl, projectId) =>
createClient<paths>({
headers: {
...(authToken ? { Authorization: `Bearer ${authToken}` } : {}),
...(projectId ? { "project-id": projectId } : {})
},
// do not include `api/v1` in the base url for this client
baseUrl: `${baseUrl ?? ''}`,
})
export const $client = computed(
[$authToken, $baseUrl, $projectId],
(authToken, baseUrl, projectId) =>
createClient<paths>({
headers: {
...(authToken ? { Authorization: `Bearer ${authToken}` } : {}),
...(projectId ? { 'project-id': projectId } : {}),
},
// do not include `api/v1` in the base url for this client
baseUrl: `${baseUrl ?? ''}`,
})
);

View File

@@ -19,7 +19,7 @@ export const boardsApi = api.injectEndpoints({
*/
listBoards: build.query<OffsetPaginatedResults_BoardDTO_, ListBoardsArg>({
query: (arg) => ({ url: 'boards/', params: arg }),
providesTags: (result, error, arg) => {
providesTags: (result) => {
// any list of boards
const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
@@ -42,7 +42,7 @@ export const boardsApi = api.injectEndpoints({
url: 'boards/',
params: { all: true },
}),
providesTags: (result, error, arg) => {
providesTags: (result) => {
// any list of boards
const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];

View File

@@ -6,7 +6,8 @@ import {
IMAGE_CATEGORIES,
IMAGE_LIMIT,
} from 'features/gallery/store/types';
import { keyBy } from 'lodash';
import { getMetadataAndWorkflowFromImageBlob } from 'features/nodes/util/getMetadataAndWorkflowFromImageBlob';
import { keyBy } from 'lodash-es';
import { ApiFullTagDescription, LIST_TAG, api } from '..';
import { components, paths } from '../schema';
import {
@@ -26,6 +27,7 @@ import {
imagesSelectors,
} from '../util';
import { boardsApi } from './boards';
import { ImageMetadataAndWorkflow } from 'features/nodes/types/types';
export const imagesApi = api.injectEndpoints({
endpoints: (build) => ({
@@ -113,6 +115,19 @@ export const imagesApi = api.injectEndpoints({
],
keepUnusedDataFor: 86400, // 24 hours
}),
getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, string>({
query: (image_name) => ({
url: `images/i/${image_name}/full`,
responseHandler: async (res) => {
return await res.blob();
},
}),
providesTags: (result, error, image_name) => [
{ type: 'ImageMetadataFromFile', id: image_name },
],
transformResponse: (response: Blob) =>
getMetadataAndWorkflowFromImageBlob(response),
}),
clearIntermediates: build.mutation<number, void>({
query: () => ({ url: `images/clear-intermediates`, method: 'POST' }),
invalidatesTags: ['IntermediatesCount'],
@@ -357,7 +372,7 @@ export const imagesApi = api.injectEndpoints({
],
async onQueryStarted(
{ imageDTO, session_id },
{ dispatch, queryFulfilled, getState }
{ dispatch, queryFulfilled }
) {
/**
* Cache changes for `changeImageSessionId`:
@@ -432,7 +447,9 @@ export const imagesApi = api.injectEndpoints({
data.updated_image_names.includes(i.image_name)
);
if (!updatedImages[0]) return;
if (!updatedImages[0]) {
return;
}
// assume all images are on the same board/category
const categories = getCategories(updatedImages[0]);
@@ -544,7 +561,9 @@ export const imagesApi = api.injectEndpoints({
data.updated_image_names.includes(i.image_name)
);
if (!updatedImages[0]) return;
if (!updatedImages[0]) {
return;
}
// assume all images are on the same board/category
const categories = getCategories(updatedImages[0]);
const boardId = updatedImages[0].board_id;
@@ -645,17 +664,7 @@ export const imagesApi = api.injectEndpoints({
},
};
},
async onQueryStarted(
{
file,
image_category,
is_intermediate,
postUploadAction,
session_id,
board_id,
},
{ dispatch, queryFulfilled }
) {
async onQueryStarted(_, { dispatch, queryFulfilled }) {
try {
/**
* NOTE: PESSIMISTIC UPDATE
@@ -712,7 +721,7 @@ export const imagesApi = api.injectEndpoints({
deleteBoard: build.mutation<DeleteBoardResult, string>({
query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE' }),
invalidatesTags: (result, error, board_id) => [
invalidatesTags: () => [
{ type: 'Board', id: LIST_TAG },
// invalidate the 'No Board' cache
{
@@ -732,7 +741,7 @@ export const imagesApi = api.injectEndpoints({
{ type: 'BoardImagesTotal', id: 'none' },
{ type: 'BoardAssetsTotal', id: 'none' },
],
async onQueryStarted(board_id, { dispatch, queryFulfilled, getState }) {
async onQueryStarted(board_id, { dispatch, queryFulfilled }) {
/**
* Cache changes for deleteBoard:
* - Update every image in the 'getImageDTO' cache that has the board_id
@@ -802,7 +811,7 @@ export const imagesApi = api.injectEndpoints({
method: 'DELETE',
params: { include_images: true },
}),
invalidatesTags: (result, error, board_id) => [
invalidatesTags: () => [
{ type: 'Board', id: LIST_TAG },
{
type: 'ImageList',
@@ -821,7 +830,7 @@ export const imagesApi = api.injectEndpoints({
{ type: 'BoardImagesTotal', id: 'none' },
{ type: 'BoardAssetsTotal', id: 'none' },
],
async onQueryStarted(board_id, { dispatch, queryFulfilled, getState }) {
async onQueryStarted(board_id, { dispatch, queryFulfilled }) {
/**
* Cache changes for deleteBoardAndImages:
* - ~~Remove every image in the 'getImageDTO' cache that has the board_id~~
@@ -1253,9 +1262,8 @@ export const imagesApi = api.injectEndpoints({
];
result?.removed_image_names.forEach((image_name) => {
const board_id = imageDTOs.find(
(i) => i.image_name === image_name
)?.board_id;
const board_id = imageDTOs.find((i) => i.image_name === image_name)
?.board_id;
if (!board_id || touchedBoardIds.includes(board_id)) {
return;
@@ -1385,4 +1393,5 @@ export const {
useDeleteBoardMutation,
useStarImagesMutation,
useUnstarImagesMutation,
useGetImageMetadataFromFileQuery,
} = imagesApi;

View File

@@ -178,7 +178,7 @@ export const modelsApi = api.injectEndpoints({
const query = queryString.stringify(params, { arrayFormat: 'none' });
return `models/?${query}`;
},
providesTags: (result, error, arg) => {
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'OnnxModel', id: LIST_TAG },
];
@@ -194,11 +194,7 @@ export const modelsApi = api.injectEndpoints({
return tags;
},
transformResponse: (
response: { models: OnnxModelConfig[] },
meta,
arg
) => {
transformResponse: (response: { models: OnnxModelConfig[] }) => {
const entities = createModelEntities<OnnxModelConfigEntity>(
response.models
);
@@ -221,7 +217,7 @@ export const modelsApi = api.injectEndpoints({
const query = queryString.stringify(params, { arrayFormat: 'none' });
return `models/?${query}`;
},
providesTags: (result, error, arg) => {
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'MainModel', id: LIST_TAG },
];
@@ -237,11 +233,7 @@ export const modelsApi = api.injectEndpoints({
return tags;
},
transformResponse: (
response: { models: MainModelConfig[] },
meta,
arg
) => {
transformResponse: (response: { models: MainModelConfig[] }) => {
const entities = createModelEntities<MainModelConfigEntity>(
response.models
);
@@ -361,7 +353,7 @@ export const modelsApi = api.injectEndpoints({
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result, error, arg) => {
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'LoRAModel', id: LIST_TAG },
];
@@ -377,11 +369,7 @@ export const modelsApi = api.injectEndpoints({
return tags;
},
transformResponse: (
response: { models: LoRAModelConfig[] },
meta,
arg
) => {
transformResponse: (response: { models: LoRAModelConfig[] }) => {
const entities = createModelEntities<LoRAModelConfigEntity>(
response.models
);
@@ -421,7 +409,7 @@ export const modelsApi = api.injectEndpoints({
void
>({
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
providesTags: (result, error, arg) => {
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'ControlNetModel', id: LIST_TAG },
];
@@ -437,11 +425,7 @@ export const modelsApi = api.injectEndpoints({
return tags;
},
transformResponse: (
response: { models: ControlNetModelConfig[] },
meta,
arg
) => {
transformResponse: (response: { models: ControlNetModelConfig[] }) => {
const entities = createModelEntities<ControlNetModelConfigEntity>(
response.models
);
@@ -453,7 +437,7 @@ export const modelsApi = api.injectEndpoints({
}),
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result, error, arg) => {
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'VaeModel', id: LIST_TAG },
];
@@ -469,11 +453,7 @@ export const modelsApi = api.injectEndpoints({
return tags;
},
transformResponse: (
response: { models: VaeModelConfig[] },
meta,
arg
) => {
transformResponse: (response: { models: VaeModelConfig[] }) => {
const entities = createModelEntities<VaeModelConfigEntity>(
response.models
);
@@ -488,7 +468,7 @@ export const modelsApi = api.injectEndpoints({
void
>({
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
providesTags: (result, error, arg) => {
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'TextualInversionModel', id: LIST_TAG },
];
@@ -504,11 +484,9 @@ export const modelsApi = api.injectEndpoints({
return tags;
},
transformResponse: (
response: { models: TextualInversionModelConfig[] },
meta,
arg
) => {
transformResponse: (response: {
models: TextualInversionModelConfig[];
}) => {
const entities = createModelEntities<TextualInversionModelConfigEntity>(
response.models
);
@@ -525,7 +503,7 @@ export const modelsApi = api.injectEndpoints({
url: `/models/search?${folderQueryStr}`,
};
},
providesTags: (result, error, arg) => {
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'ScannedModels', id: LIST_TAG },
];

View File

@@ -16,6 +16,7 @@ export const tagTypes = [
'ImageNameList',
'ImageList',
'ImageMetadata',
'ImageMetadataFromFile',
'Model',
];
export type ApiFullTagDescription = FullTagDescription<
@@ -39,7 +40,7 @@ const dynamicBaseQuery: BaseQueryFn<
headers.set('Authorization', `Bearer ${authToken}`);
}
if (projectId) {
headers.set("project-id", projectId)
headers.set('project-id', projectId);
}
return headers;

File diff suppressed because it is too large Load Diff

View File

@@ -1,14 +1,16 @@
import { createAsyncThunk } from '@reduxjs/toolkit';
function getCircularReplacer() {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const ancestors: Record<string, any>[] = [];
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return function (key: string, value: any) {
if (typeof value !== 'object' || value === null) {
return value;
}
// `this` is the object that value is contained in,
// i.e., its direct parent.
// @ts-ignore
// `this` is the object that value is contained in, i.e., its direct parent.
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore don't think it's possible to not have TS complain about this...
while (ancestors.length > 0 && ancestors.at(-1) !== this) {
ancestors.pop();
}

View File

@@ -73,7 +73,7 @@ export const sessionInvoked = createAsyncThunk<
>('api/sessionInvoked', async (arg, { rejectWithValue }) => {
const { session_id } = arg;
const { PUT } = $client.get();
const { data, error, response } = await PUT(
const { error, response } = await PUT(
'/api/v1/sessions/{session_id}/invoke',
{
params: { query: { all: true }, path: { session_id } },
@@ -85,6 +85,7 @@ export const sessionInvoked = createAsyncThunk<
return rejectWithValue({
arg,
status: response.status,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
error: (error as any).body.detail,
});
}
@@ -124,14 +125,11 @@ export const sessionCanceled = createAsyncThunk<
>('api/sessionCanceled', async (arg, { rejectWithValue }) => {
const { session_id } = arg;
const { DELETE } = $client.get();
const { data, error, response } = await DELETE(
'/api/v1/sessions/{session_id}/invoke',
{
params: {
path: { session_id },
},
}
);
const { data, error } = await DELETE('/api/v1/sessions/{session_id}/invoke', {
params: {
path: { session_id },
},
});
if (error) {
return rejectWithValue({ arg, error });
@@ -164,7 +162,7 @@ export const listedSessions = createAsyncThunk<
>('api/listSessions', async (arg, { rejectWithValue }) => {
const { params } = arg;
const { GET } = $client.get();
const { data, error, response } = await GET('/api/v1/sessions/', {
const { data, error } = await GET('/api/v1/sessions/', {
params,
});

View File

@@ -26,15 +26,21 @@ export const getIsImageInDateRange = (
for (let index = 0; index < totalCachedImageDtos.length; index++) {
const image = totalCachedImageDtos[index];
if (image?.starred) cachedStarredImages.push(image);
if (!image?.starred) cachedUnstarredImages.push(image);
if (image?.starred) {
cachedStarredImages.push(image);
}
if (!image?.starred) {
cachedUnstarredImages.push(image);
}
}
if (imageDTO.starred) {
const lastStarredImage =
cachedStarredImages[cachedStarredImages.length - 1];
// if starring or already starred, want to look in list of starred images
if (!lastStarredImage) return true; // no starred images showing, so always show this one
if (!lastStarredImage) {
return true;
} // no starred images showing, so always show this one
const createdDate = new Date(imageDTO.created_at);
const oldestDate = new Date(lastStarredImage.created_at);
return createdDate >= oldestDate;
@@ -42,7 +48,9 @@ export const getIsImageInDateRange = (
const lastUnstarredImage =
cachedUnstarredImages[cachedUnstarredImages.length - 1];
// if unstarring or already unstarred, want to look in list of unstarred images
if (!lastUnstarredImage) return false; // no unstarred images showing, so don't show this one
if (!lastUnstarredImage) {
return false;
} // no unstarred images showing, so don't show this one
const createdDate = new Date(imageDTO.created_at);
const oldestDate = new Date(lastUnstarredImage.created_at);
return createdDate >= oldestDate;